diff --git a/README.md b/README.md
index 5f6e047..3e8fd01 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-## Real-time Hand-Detection using Neural Networks (SSD) on Tensorflow.
+## Real-time Hand-Detection using Neural Networks (SSD) on Tensorflow 2.
 
-This repo documents steps and scripts used to train a hand detector using Tensorflow (Object Detection API). As with any DNN based task, the most expensive (and riskiest) part of the process has to do with finding or creating the right (annotated) dataset. I was interested mainly in detecting hands on a table (egocentric view point). I experimented first with the [Oxford Hands Dataset](http://www.robots.ox.ac.uk/~vgg/data/hands/) (the results were not good). I then tried the [Egohands Dataset](http://vision.soic.indiana.edu/projects/egohands/) which was a much better fit to my requirements.
+This repo documents steps and scripts used to train a hand detector using Tensorflow 2 (Object Detection API). As with any DNN based task, the most expensive (and riskiest) part of the process has to do with finding or creating the right (annotated) dataset. I was interested mainly in detecting hands on a table (egocentric viewpoint). I experimented first with the [Oxford Hands Dataset](http://www.robots.ox.ac.uk/~vgg/data/hands/) (the results were not good). I then tried the [Egohands Dataset](http://vision.soic.indiana.edu/projects/egohands/), which was a much better fit for my requirements.
 
 The goal of this repo/post is to demonstrate how neural networks can be applied to the (hard) problem of tracking hands (egocentric and other views). Better still, provide code that can be adapted to other use cases.
@@ -23,7 +23,7 @@ Both examples above were run on a macbook pro **CPU** (i7, 2.5GHz, 16GB). Some f
 | 16 | 320 * 240 | Macbook pro (i7, 2.5GHz, 16GB) | Run while visualizing results (image above) |
 | 11 | 640 * 480 | Macbook pro (i7, 2.5GHz, 16GB) | Run while visualizing results (image above) |
 
-> Note: The code in this repo is written and tested with Tensorflow `1.4.0-rc0`. Using a different version may result in [some errors](https://github.com/tensorflow/models/issues/1581).
+> Note: The code in this repo is written and tested with Tensorflow `2.3.1`. Using a different version may result in [some errors](https://github.com/tensorflow/models/issues/1581). You may need to [generate your own frozen model](https://pythonprogramming.net/testing-custom-object-detector-tensorflow-object-detection-api-tutorial/?completed=/training-custom-objects-tensorflow-object-detection-api-tutorial/) graph using the [model checkpoints](model-checkpoint) in the repo to fit your TF version. The tensorflow object detection repo has a [python file for exporting a checkpoint to frozen graph here](https://github.com/tensorflow/models/blob/master/research/object_detection/export_inference_graph.py). You can copy it to the current directory and use it as follows:
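For reference, a typical invocation of that exporter script looks something like the following. This is illustrative only: the pipeline config path and checkpoint number are placeholders for your own training run; `model-checkpoint` and `hand_inference_graph` are the directory names used in this repo.

```bash
python export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path path/to/your_pipeline.config \
    --trained_checkpoint_prefix model-checkpoint/model.ckpt-XXXX \
    --output_directory hand_inference_graph
```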
diff --git a/detect_multi_threaded.py b/detect_multi_threaded.py
index ef2a879..92883ad 100644
--- a/detect_multi_threaded.py
+++ b/detect_multi_threaded.py
@@ -18,7 +18,7 @@ def worker(input_q, output_q, cap_params, frame_processed):
     print(">> loading frozen model for worker")
     detection_graph, sess = detector_utils.load_inference_graph()
-    sess = tf.Session(graph=detection_graph)
+    sess = tf.compat.v1.Session(graph=detection_graph)
     while True:
         #print("> ===== in worker loop, frame ", frame_processed)
         frame = input_q.get()
diff --git a/detect_multi_threaded.py.old b/detect_multi_threaded.py.old
new file mode 100644
index 0000000..ef2a879
--- /dev/null
+++ b/detect_multi_threaded.py.old
@@ -0,0 +1,171 @@
+from utils import detector_utils as detector_utils
+import cv2
+import tensorflow as tf
+import multiprocessing
+from multiprocessing import Queue, Pool
+import time
+from utils.detector_utils import WebcamVideoStream
+import datetime
+import argparse
+
+frame_processed = 0
+score_thresh = 0.2
+
+# Create a worker thread that loads graph and
+# does detection on images in an input queue and puts it on an output queue
+
+
+def worker(input_q, output_q, cap_params, frame_processed):
+    print(">> loading frozen model for worker")
+    detection_graph, sess = detector_utils.load_inference_graph()
+    sess = tf.Session(graph=detection_graph)
+    while True:
+        #print("> ===== in worker loop, frame ", frame_processed)
+        frame = input_q.get()
+        if (frame is not None):
+            # Actual detection. Variable boxes contains the bounding box cordinates for hands detected,
+            # while scores contains the confidence for each of these boxes.
+            # Hint: If len(boxes) > 1 , you may assume you have found atleast one hand (within your score threshold)
+
+            boxes, scores = detector_utils.detect_objects(
+                frame, detection_graph, sess)
+            # draw bounding boxes
+            detector_utils.draw_box_on_image(
+                cap_params['num_hands_detect'], cap_params["score_thresh"],
+                scores, boxes, cap_params['im_width'], cap_params['im_height'],
+                frame)
+            # add frame annotated with bounding box to queue
+            output_q.put(frame)
+            frame_processed += 1
+        else:
+            output_q.put(frame)
+    sess.close()
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '-src',
+        '--source',
+        dest='video_source',
+        type=int,
+        default=0,
+        help='Device index of the camera.')
+    parser.add_argument(
+        '-nhands',
+        '--num_hands',
+        dest='num_hands',
+        type=int,
+        default=2,
+        help='Max number of hands to detect.')
+    parser.add_argument(
+        '-fps',
+        '--fps',
+        dest='fps',
+        type=int,
+        default=1,
+        help='Show FPS on detection/display visualization')
+    parser.add_argument(
+        '-wd',
+        '--width',
+        dest='width',
+        type=int,
+        default=300,
+        help='Width of the frames in the video stream.')
+    parser.add_argument(
+        '-ht',
+        '--height',
+        dest='height',
+        type=int,
+        default=200,
+        help='Height of the frames in the video stream.')
+    parser.add_argument(
+        '-ds',
+        '--display',
+        dest='display',
+        type=int,
+        default=1,
+        help='Display the detected images using OpenCV. This reduces FPS')
+    parser.add_argument(
+        '-num-w',
+        '--num-workers',
+        dest='num_workers',
+        type=int,
+        default=4,
+        help='Number of workers.')
+    parser.add_argument(
+        '-q-size',
+        '--queue-size',
+        dest='queue_size',
+        type=int,
+        default=5,
+        help='Size of the queue.')
+    args = parser.parse_args()
+
+    input_q = Queue(maxsize=args.queue_size)
+    output_q = Queue(maxsize=args.queue_size)
+
+    video_capture = WebcamVideoStream(
+        src=args.video_source, width=args.width, height=args.height).start()
+
+    cap_params = {}
+    frame_processed = 0
+    cap_params['im_width'], cap_params['im_height'] = video_capture.size()
+    cap_params['score_thresh'] = score_thresh
+
+    # max number of hands we want to detect/track
+    cap_params['num_hands_detect'] = args.num_hands
+
+    print(cap_params, args)
+
+    # spin up workers to paralleize detection.
+    pool = Pool(args.num_workers, worker,
+                (input_q, output_q, cap_params, frame_processed))
+
+    start_time = datetime.datetime.now()
+    num_frames = 0
+    fps = 0
+    index = 0
+
+    cv2.namedWindow('Multi-Threaded Detection', cv2.WINDOW_NORMAL)
+
+    while True:
+        frame = video_capture.read()
+        frame = cv2.flip(frame, 1)
+        index += 1
+
+        input_q.put(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+        output_frame = output_q.get()
+
+        output_frame = cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR)
+
+        elapsed_time = (datetime.datetime.now() - start_time).total_seconds()
+        num_frames += 1
+        fps = num_frames / elapsed_time
+        # print("frame ", index, num_frames, elapsed_time, fps)
+
+        if (output_frame is not None):
+            if (args.display > 0):
+                if (args.fps > 0):
+                    detector_utils.draw_fps_on_image("FPS : " + str(int(fps)),
+                                                     output_frame)
+                cv2.imshow('Multi-Threaded Detection', output_frame)
+                if cv2.waitKey(1) & 0xFF == ord('q'):
+                    break
+            else:
+                if (num_frames == 400):
+                    num_frames = 0
+                    start_time = datetime.datetime.now()
+                else:
+                    print("frames processed: ", index, "elapsed time: ",
+                          elapsed_time, "fps: ", str(int(fps)))
+        else:
+            # print("video end")
+            break
+    elapsed_time = (datetime.datetime.now() - start_time).total_seconds()
+    fps = num_frames / elapsed_time
+    print("fps", fps)
+    pool.terminate()
+    video_capture.stop()
+    cv2.destroyAllWindows()
diff --git a/detect_single_threaded.py.old b/detect_single_threaded.py.old
new file mode 100644
index 0000000..4800081
--- /dev/null
+++ b/detect_single_threaded.py.old
@@ -0,0 +1,121 @@
+from utils import detector_utils as detector_utils
+import cv2
+import tensorflow as tf
+import datetime
+import argparse
+
+detection_graph, sess = detector_utils.load_inference_graph()
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '-sth',
+        '--scorethreshold',
+        dest='score_thresh',
+        type=float,
+        default=0.2,
+        help='Score threshold for displaying bounding boxes')
+    parser.add_argument(
+        '-fps',
+        '--fps',
+        dest='fps',
+        type=int,
+        default=1,
+        help='Show FPS on detection/display visualization')
+    parser.add_argument(
+        '-src',
+        '--source',
+        dest='video_source',
+        default=0,
+        help='Device index of the camera.')
+    parser.add_argument(
+        '-wd',
+        '--width',
+        dest='width',
+        type=int,
+        default=320,
+        help='Width of the frames in the video stream.')
+    parser.add_argument(
+        '-ht',
+        '--height',
+        dest='height',
+        type=int,
+        default=180,
+        help='Height of the frames in the video stream.')
+    parser.add_argument(
+        '-ds',
+        '--display',
+        dest='display',
+        type=int,
+        default=1,
+        help='Display the detected images using OpenCV. This reduces FPS')
+    parser.add_argument(
+        '-num-w',
+        '--num-workers',
+        dest='num_workers',
+        type=int,
+        default=4,
+        help='Number of workers.')
+    parser.add_argument(
+        '-q-size',
+        '--queue-size',
+        dest='queue_size',
+        type=int,
+        default=5,
+        help='Size of the queue.')
+    args = parser.parse_args()
+
+    cap = cv2.VideoCapture(args.video_source)
+    cap.set(cv2.CAP_PROP_FRAME_WIDTH, args.width)
+    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, args.height)
+
+    start_time = datetime.datetime.now()
+    num_frames = 0
+    im_width, im_height = (cap.get(3), cap.get(4))
+    # max number of hands we want to detect/track
+    num_hands_detect = 2
+
+    cv2.namedWindow('Single-Threaded Detection', cv2.WINDOW_NORMAL)
+
+    while True:
+        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
+        ret, image_np = cap.read()
+        # image_np = cv2.flip(image_np, 1)
+        try:
+            image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
+        except:
+            print("Error converting to RGB")
+
+        # Actual detection. Variable boxes contains the bounding box cordinates for hands detected,
+        # while scores contains the confidence for each of these boxes.
+        # Hint: If len(boxes) > 1 , you may assume you have found atleast one hand (within your score threshold)
+
+        boxes, scores = detector_utils.detect_objects(image_np,
+                                                      detection_graph, sess)
+
+        # draw bounding boxes on frame
+        detector_utils.draw_box_on_image(num_hands_detect, args.score_thresh,
+                                         scores, boxes, im_width, im_height,
+                                         image_np)
+
+        # Calculate Frames per second (FPS)
+        num_frames += 1
+        elapsed_time = (datetime.datetime.now() - start_time).total_seconds()
+        fps = num_frames / elapsed_time
+
+        if (args.display > 0):
+            # Display FPS on frame
+            if (args.fps > 0):
+                detector_utils.draw_fps_on_image("FPS : " + str(int(fps)),
+                                                 image_np)
+
+            cv2.imshow('Single-Threaded Detection',
+                       cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
+
+            if cv2.waitKey(25) & 0xFF == ord('q'):
+                cv2.destroyAllWindows()
+                break
+        else:
+            print("frames processed: ", num_frames, "elapsed time: ",
+                  elapsed_time, "fps: ", str(int(fps)))
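Beyond the two webcam demos, the same `detector_utils` API can be pointed at a single still image. A minimal sketch follows; the image file name, hand count, and score threshold are arbitrary placeholders and the script itself is not part of the repo.

```python
import cv2
from utils import detector_utils

# Load the frozen hand detector once (returns the graph plus a tf.compat.v1 session).
detection_graph, sess = detector_utils.load_inference_graph()

image_bgr = cv2.imread("test_image.jpg")                # OpenCV reads images as BGR
image_np = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)   # the detector expects RGB
im_height, im_width = image_np.shape[:2]

# boxes are normalized [ymin, xmin, ymax, xmax]; scores are per-box confidences.
boxes, scores = detector_utils.detect_objects(image_np, detection_graph, sess)

# Draw up to 2 hands whose score exceeds 0.2, then save the annotated image.
detector_utils.draw_box_on_image(2, 0.2, scores, boxes, im_width, im_height, image_np)
cv2.imwrite("test_image_out.jpg", cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
sess.close()
```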
diff --git a/report.txt b/report.txt
new file mode 100644
index 0000000..9adade7
--- /dev/null
+++ b/report.txt
@@ -0,0 +1,17 @@
+TensorFlow 2.0 Upgrade Script
+-----------------------------
+Converted 1 files
+Detected 0 issues that require attention
+--------------------------------------------------------------------------------
+================================================================================
+Detailed log follows:
+
+================================================================================
+--------------------------------------------------------------------------------
+Processing file 'detect_single_threaded_v1.py'
+ outputting to 'detect_single_threaded.py'
+--------------------------------------------------------------------------------
+
+
+--------------------------------------------------------------------------------
+
diff --git a/reportfile.txt b/reportfile.txt
new file mode 100644
index 0000000..128fb61
--- /dev/null
+++ b/reportfile.txt
@@ -0,0 +1,38 @@
+TensorFlow 2.0 Upgrade Script
+-----------------------------
+Converted 3 files
+Detected 0 issues that require attention
+--------------------------------------------------------------------------------
+================================================================================
+Detailed log follows:
+
+================================================================================
+================================================================================
+Input tree: 'utils-v1/'
+================================================================================
+--------------------------------------------------------------------------------
+Processing file 'utils-v1/detector_utils.py'
+ outputting to 'utils/detector_utils.py'
+--------------------------------------------------------------------------------
+
+41:23: INFO: Renamed 'tf.GraphDef' to 'tf.compat.v1.GraphDef'
+42:13: INFO: Renamed 'tf.gfile.GFile' to 'tf.io.gfile.GFile'
+46:15: INFO: Renamed 'tf.Session' to 'tf.compat.v1.Session'
+--------------------------------------------------------------------------------
+
+--------------------------------------------------------------------------------
+Processing file 'utils-v1/label_map_util.py'
+ outputting to 'utils/label_map_util.py'
+--------------------------------------------------------------------------------
+
+116:9: INFO: Renamed 'tf.gfile.GFile' to 'tf.io.gfile.GFile'
+--------------------------------------------------------------------------------
+
+--------------------------------------------------------------------------------
+Processing file 'utils-v1/__init__.py'
+ outputting to 'utils/__init__.py'
+--------------------------------------------------------------------------------
+
+
+--------------------------------------------------------------------------------
+
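The two reports above come from TensorFlow's automatic upgrade tool (`tf_upgrade_v2`). The exact commands are not recorded in the repo, but judging from the file names in the reports, the invocations would have looked something like this:

```bash
# Single file: detect_single_threaded_v1.py -> detect_single_threaded.py (report.txt)
tf_upgrade_v2 --infile detect_single_threaded_v1.py --outfile detect_single_threaded.py --reportfile report.txt

# Whole package: utils-v1/ -> utils/ (reportfile.txt)
tf_upgrade_v2 --intree utils-v1/ --outtree utils/ --reportfile reportfile.txt
```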
diff --git a/utils-v1/__init__.py b/utils-v1/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils-v1/detector_utils.py b/utils-v1/detector_utils.py
new file mode 100644
index 0000000..067cdd1
--- /dev/null
+++ b/utils-v1/detector_utils.py
@@ -0,0 +1,135 @@
+# Utilities for object detector.
+
+import numpy as np
+import sys
+import tensorflow as tf
+import os
+from threading import Thread
+from datetime import datetime
+import cv2
+from utils import label_map_util
+from collections import defaultdict
+
+
+detection_graph = tf.Graph()
+sys.path.append("..")
+
+# score threshold for showing bounding boxes.
+_score_thresh = 0.27
+
+MODEL_NAME = 'hand_inference_graph'
+# Path to frozen detection graph. This is the actual model that is used for the object detection.
+PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
+# List of the strings that is used to add correct label for each box.
+PATH_TO_LABELS = os.path.join(MODEL_NAME, 'hand_label_map.pbtxt')
+
+NUM_CLASSES = 1
+# load label map
+label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
+categories = label_map_util.convert_label_map_to_categories(
+    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
+category_index = label_map_util.create_category_index(categories)
+
+
+# Load a frozen infrerence graph into memory
+def load_inference_graph():
+
+    # load frozen tensorflow model into memory
+    print("> ====== loading HAND frozen graph into memory")
+    detection_graph = tf.Graph()
+    with detection_graph.as_default():
+        od_graph_def = tf.GraphDef()
+        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
+            serialized_graph = fid.read()
+            od_graph_def.ParseFromString(serialized_graph)
+            tf.import_graph_def(od_graph_def, name='')
+        sess = tf.Session(graph=detection_graph)
+    print("> ====== Hand Inference graph loaded.")
+    return detection_graph, sess
+
+
+# draw the detected bounding boxes on the images
+# You can modify this to also draw a label.
+def draw_box_on_image(num_hands_detect, score_thresh, scores, boxes, im_width, im_height, image_np):
+    for i in range(num_hands_detect):
+        if (scores[i] > score_thresh):
+            (left, right, top, bottom) = (boxes[i][1] * im_width, boxes[i][3] * im_width,
+                                          boxes[i][0] * im_height, boxes[i][2] * im_height)
+            p1 = (int(left), int(top))
+            p2 = (int(right), int(bottom))
+            cv2.rectangle(image_np, p1, p2, (77, 255, 9), 3, 1)
+
+
+# Show fps value on image.
+def draw_fps_on_image(fps, image_np):
+    cv2.putText(image_np, fps, (20, 50),
+                cv2.FONT_HERSHEY_SIMPLEX, 0.75, (77, 255, 9), 2)
+
+
+# Actual detection .. generate scores and bounding boxes given an image
+def detect_objects(image_np, detection_graph, sess):
+    # Definite input and output Tensors for detection_graph
+    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
+    # Each box represents a part of the image where a particular object was detected.
+    detection_boxes = detection_graph.get_tensor_by_name(
+        'detection_boxes:0')
+    # Each score represent how level of confidence for each of the objects.
+    # Score is shown on the result image, together with the class label.
+    detection_scores = detection_graph.get_tensor_by_name(
+        'detection_scores:0')
+    detection_classes = detection_graph.get_tensor_by_name(
+        'detection_classes:0')
+    num_detections = detection_graph.get_tensor_by_name(
+        'num_detections:0')
+
+    image_np_expanded = np.expand_dims(image_np, axis=0)
+
+    (boxes, scores, classes, num) = sess.run(
+        [detection_boxes, detection_scores,
+         detection_classes, num_detections],
+        feed_dict={image_tensor: image_np_expanded})
+    return np.squeeze(boxes), np.squeeze(scores)
+
+
+# Code to thread reading camera input.
+# Source : Adrian Rosebrock
+# https://www.pyimagesearch.com/2017/02/06/faster-video-file-fps-with-cv2-videocapture-and-opencv/
+class WebcamVideoStream:
+    def __init__(self, src, width, height):
+        # initialize the video camera stream and read the first frame
+        # from the stream
+        self.stream = cv2.VideoCapture(src)
+        self.stream.set(cv2.CAP_PROP_FRAME_WIDTH, width)
+        self.stream.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
+        (self.grabbed, self.frame) = self.stream.read()
+
+        # initialize the variable used to indicate if the thread should
+        # be stopped
+        self.stopped = False
+
+    def start(self):
+        # start the thread to read frames from the video stream
+        Thread(target=self.update, args=()).start()
+        return self
+
+    def update(self):
+        # keep looping infinitely until the thread is stopped
+        while True:
+            # if the thread indicator variable is set, stop the thread
+            if self.stopped:
+                return
+
+            # otherwise, read the next frame from the stream
+            (self.grabbed, self.frame) = self.stream.read()
+
+    def read(self):
+        # return the frame most recently read
+        return self.frame
+
+    def size(self):
+        # return size of the capture device
+        return self.stream.get(3), self.stream.get(4)
+
+    def stop(self):
+        # indicate that the thread should be stopped
+        self.stopped = True
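The comment on `draw_box_on_image` above suggests extending it to also draw a label. A minimal variant is sketched below; it assumes the single `hand` class used by this detector and is not part of the repo.

```python
import cv2

def draw_box_with_label_on_image(num_hands_detect, score_thresh, scores, boxes,
                                 im_width, im_height, image_np, label="hand"):
    # Same box math as draw_box_on_image, plus a text caption above each box.
    for i in range(num_hands_detect):
        if scores[i] > score_thresh:
            left, right = boxes[i][1] * im_width, boxes[i][3] * im_width
            top, bottom = boxes[i][0] * im_height, boxes[i][2] * im_height
            p1, p2 = (int(left), int(top)), (int(right), int(bottom))
            cv2.rectangle(image_np, p1, p2, (77, 255, 9), 3, 1)
            caption = "{} {:.2f}".format(label, scores[i])
            cv2.putText(image_np, caption, (int(left), max(int(top) - 5, 0)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (77, 255, 9), 2)
```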
diff --git a/utils-v1/label_map_util.py b/utils-v1/label_map_util.py
new file mode 100644
index 0000000..be67eab
--- /dev/null
+++ b/utils-v1/label_map_util.py
@@ -0,0 +1,140 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Label map utility functions."""
+
+import logging
+
+import tensorflow as tf
+from google.protobuf import text_format
+from protos import string_int_label_map_pb2
+
+
+def _validate_label_map(label_map):
+  """Checks if a label map is valid.
+
+  Args:
+    label_map: StringIntLabelMap to validate.
+
+  Raises:
+    ValueError: if label map is invalid.
+  """
+  for item in label_map.item:
+    if item.id < 1:
+      raise ValueError('Label map ids should be >= 1.')
+
+
+def create_category_index(categories):
+  """Creates dictionary of COCO compatible categories keyed by category id.
+
+  Args:
+    categories: a list of dicts, each of which has the following keys:
+      'id': (required) an integer id uniquely identifying this category.
+      'name': (required) string representing category name
+        e.g., 'cat', 'dog', 'pizza'.
+
+  Returns:
+    category_index: a dict containing the same entries as categories, but keyed
+      by the 'id' field of each category.
+  """
+  category_index = {}
+  for cat in categories:
+    category_index[cat['id']] = cat
+  return category_index
+
+
+def convert_label_map_to_categories(label_map,
+                                    max_num_classes,
+                                    use_display_name=True):
+  """Loads label map proto and returns categories list compatible with eval.
+
+  This function loads a label map and returns a list of dicts, each of which
+  has the following keys:
+    'id': (required) an integer id uniquely identifying this category.
+    'name': (required) string representing category name
+      e.g., 'cat', 'dog', 'pizza'.
+  We only allow class into the list if its id-label_id_offset is
+  between 0 (inclusive) and max_num_classes (exclusive).
+  If there are several items mapping to the same id in the label map,
+  we will only keep the first one in the categories list.
+
+  Args:
+    label_map: a StringIntLabelMapProto or None. If None, a default categories
+      list is created with max_num_classes categories.
+    max_num_classes: maximum number of (consecutive) label indices to include.
+    use_display_name: (boolean) choose whether to load 'display_name' field
+      as category name. If False or if the display_name field does not exist,
+      uses 'name' field as category names instead.
+  Returns:
+    categories: a list of dictionaries representing all possible categories.
+  """
+  categories = []
+  list_of_ids_already_added = []
+  if not label_map:
+    label_id_offset = 1
+    for class_id in range(max_num_classes):
+      categories.append({
+          'id': class_id + label_id_offset,
+          'name': 'category_{}'.format(class_id + label_id_offset)
+      })
+    return categories
+  for item in label_map.item:
+    if not 0 < item.id <= max_num_classes:
+      logging.info('Ignore item %d since it falls outside of requested '
+                   'label range.', item.id)
+      continue
+    if use_display_name and item.HasField('display_name'):
+      name = item.display_name
+    else:
+      name = item.name
+    if item.id not in list_of_ids_already_added:
+      list_of_ids_already_added.append(item.id)
+      categories.append({'id': item.id, 'name': name})
+  return categories
+
+
+def load_labelmap(path):
+  """Loads label map proto.
+
+  Args:
+    path: path to StringIntLabelMap proto text file.
+  Returns:
+    a StringIntLabelMapProto
+  """
+  with tf.gfile.GFile(path, 'r') as fid:
+    label_map_string = fid.read()
+    label_map = string_int_label_map_pb2.StringIntLabelMap()
+    try:
+      text_format.Merge(label_map_string, label_map)
+    except text_format.ParseError:
+      label_map.ParseFromString(label_map_string)
+  _validate_label_map(label_map)
+  return label_map
+
+
+def get_label_map_dict(label_map_path):
+  """Reads a label map and returns a dictionary of label names to id.
+
+  Args:
+    label_map_path: path to label_map.
+
+  Returns:
+    A dictionary mapping label names to id.
+  """
+  label_map = load_labelmap(label_map_path)
+  label_map_dict = {}
+  for item in label_map.item:
+    label_map_dict[item.name] = item.id
+  return label_map_dict
diff --git a/utils/detector_utils.py b/utils/detector_utils.py
index 067cdd1..5e804da 100644
--- a/utils/detector_utils.py
+++ b/utils/detector_utils.py
@@ -38,12 +38,12 @@ def load_inference_graph():
     print("> ====== loading HAND frozen graph into memory")
     detection_graph = tf.Graph()
     with detection_graph.as_default():
-        od_graph_def = tf.GraphDef()
-        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
+        od_graph_def = tf.compat.v1.GraphDef()
+        with tf.io.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
             serialized_graph = fid.read()
             od_graph_def.ParseFromString(serialized_graph)
             tf.import_graph_def(od_graph_def, name='')
-        sess = tf.Session(graph=detection_graph)
+        sess = tf.compat.v1.Session(graph=detection_graph)
     print("> ====== Hand Inference graph loaded.")
     return detection_graph, sess
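The change above keeps the graph-and-session workflow alive through `tf.compat.v1`, which is the least invasive way to run a frozen `.pb` under TensorFlow 2. If you prefer to avoid sessions entirely, a frozen `GraphDef` can instead be wrapped into a plain callable; below is a sketch adapted from the TensorFlow migration guide, not part of this repo.

```python
import tensorflow as tf

def wrap_frozen_graph(pb_path, inputs, outputs):
    # Read the serialized GraphDef from the frozen .pb file.
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    # Import the graph inside a wrapped function, then prune it down to the
    # requested input/output tensors, yielding a directly callable function.
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped = tf.compat.v1.wrap_function(_imports_graph_def, [])
    graph = wrapped.graph
    return wrapped.prune(
        tf.nest.map_structure(graph.as_graph_element, inputs),
        tf.nest.map_structure(graph.as_graph_element, outputs))

# Hypothetical usage with the tensor names from detector_utils.detect_objects:
# detect_fn = wrap_frozen_graph(
#     "hand_inference_graph/frozen_inference_graph.pb",
#     inputs="image_tensor:0",
#     outputs=["detection_boxes:0", "detection_scores:0"])
# boxes, scores = detect_fn(tf.convert_to_tensor(image_np_expanded))
```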
diff --git a/utils/label_map_util.py b/utils/label_map_util.py
index be67eab..9135e22 100644
--- a/utils/label_map_util.py
+++ b/utils/label_map_util.py
@@ -113,7 +113,7 @@ def load_labelmap(path):
   Returns:
     a StringIntLabelMapProto
   """
-  with tf.gfile.GFile(path, 'r') as fid:
+  with tf.io.gfile.GFile(path, 'r') as fid:
     label_map_string = fid.read()
     label_map = string_int_label_map_pb2.StringIntLabelMap()
     try: