From 1b5d51bcf069f658e5845a69a610fa2cec1fc041 Mon Sep 17 00:00:00 2001 From: Alexander James Wallar Date: Mon, 10 Feb 2014 00:18:50 +0000 Subject: [PATCH] Start breaking sounds into classes --- locaudio/api.py | 20 +++++++++ locaudio/config.py | 1 + locaudio/db.py | 84 +++++++++++++++++++++++++++++++------ locaudio/detectionserver.py | 44 ++++++++++++++++--- tests/test_server.py | 11 ++++- 5 files changed, 141 insertions(+), 19 deletions(-) diff --git a/locaudio/api.py b/locaudio/api.py index 21740ed..099cc88 100644 --- a/locaudio/api.py +++ b/locaudio/api.py @@ -14,6 +14,7 @@ def __init__(self, host, port): self.pos_url = self.url + "/locations" self.notify_url = self.url + "/notify" self.names_url = self.url + "/names" + self.class_pos_url = self.url + "/class/locations" def make_position_url(self, sound_name): @@ -38,6 +39,25 @@ def get_sound_locations(self, sound_name): return ret_list + def get_class_locations(self, class_name): + req = urllib2.urlopen(self.class_pos_url + "/" + class_name) + location_list = json.loads(req.read()) + ret_list = list() + + for location in location_list: + position = Point( + location["position"]["x"], + location["position"]["y"] + ) + + ret_list.append( + Location(position, location["confidence"]) + ) + + return ret_list + + + def get_names(self): req = urllib2.urlopen(self.names_url) names_dict = json.loads(req.read()) diff --git a/locaudio/config.py b/locaudio/config.py index cef9711..6b59022 100644 --- a/locaudio/config.py +++ b/locaudio/config.py @@ -8,6 +8,7 @@ app.config.from_object(__name__) detection_events = dict() +class_detection_events = dict() new_data = dict() this = sys.modules[__name__] diff --git a/locaudio/db.py b/locaudio/db.py index 53eb5a4..39468e1 100644 --- a/locaudio/db.py +++ b/locaudio/db.py @@ -3,6 +3,8 @@ import fingerprint import config +from multiprocessing import Pool + """ @@ -10,6 +12,7 @@ [ { name: , + class: , fingerprint: , distance: , spl: @@ -94,6 +97,43 @@ def init(): return False +class SimilarityFunction: + + def __init__(self, f_check): + self.f_check = f_check + + + def __call__(self, db_dict): + return { + "conf": fingerprint.get_similarity( + self.f_check, db_dict["fingerprint"] + ), + "name": db_dict["name"], + "class": db_dict["class"] + } + + +def determine_sound_class(conf_list): + + class_conf_dict = dict() + class_count_dict = dict() + + for db_dict in conf_list: + if not db_dict["class"] in class_conf_dict.keys(): + class_conf_dict[db_dict["class"]] = 0.0 + class_count_dict[db_dict["class"]] = 0 + + class_conf_dict[db_dict["class"]] += db_dict["conf"] + class_count_dict[db_dict["class"]] += 1 + + best_match = max( + class_conf_dict.items(), + key=lambda ct: ct[1] / float(class_count_dict[ct[0]]) + ) + + return best_match[0] + + def get_best_matching_print(f_in): """ @@ -107,22 +147,42 @@ def get_best_matching_print(f_in): """ - conn = r.connect(host=HOST, port=PORT, db=DB) + p_pool = Pool(processes=4) - table = list(r.table(FINGERPRINT_TABLE).run(conn)) + try: + conn = r.connect(host=HOST, port=PORT, db=DB) - if len(table) == 0: - raise LookupError("Database is empty") + table = list(r.table(FINGERPRINT_TABLE).run(conn)) - conf_list = map( - lambda info: { - "conf": fingerprint.get_similarity(f_in, info["fingerprint"]), - "name": info["name"] - }, table - ) + if len(table) == 0: + raise LookupError("Database is empty") + + map_func = SimilarityFunction(f_in) + conf_list = p_pool.map(map_func, table) + + best_match = max(conf_list, key=lambda ct: ct["conf"]) + sound_class = determine_sound_class(conf_list) + + return best_match["name"], sound_class, best_match["conf"] + finally: + p_pool.terminate() + + +def get_class_reference_data(class_name): + conn = r.connect(host=HOST, port=PORT, db=DB) + ref_list = list(r.table(FINGERPRINT_TABLE).get_all( + class_name, index=FINGERPRINT_SECONDARY_KEY + ).run(conn)) + + + avg_distance = 0.0 + avg_spl = 0.0 + for ref_data in ref_list: + avg_distance += ref_data["distance"] + avg_spl += ref_data["spl"] - best_match = max(conf_list, key=lambda ct: ct["conf"]) - return best_match["name"], best_match["conf"] + len_ref_list = float(len(list(ref_list))) + return avg_distance / len_ref_list, avg_spl / len_ref_list def get_reference_data(ref_name): diff --git a/locaudio/detectionserver.py b/locaudio/detectionserver.py index 430aa1f..8194f20 100644 --- a/locaudio/detectionserver.py +++ b/locaudio/detectionserver.py @@ -60,20 +60,23 @@ def post_notify(): req_print = json.loads(request.form["fingerprint"]) - sound_name, confidence = db.get_best_matching_print(req_print) + sound_name, sound_class, confidence = db.get_best_matching_print(req_print) if confidence > MIN_CONFIDENCE: if not sound_name in config.detection_events.keys(): config.detection_events[sound_name] = list() + config.class_detection_events[sound_class] = list() config.new_data[sound_name] = True if len(config.detection_events[sound_name]) + 1 >= MAX_NODE_EVENTS: del config.detection_events[sound_name][0] - config.detection_events[sound_name].append( - request_to_detection_event(request.form, confidence) - ) + d_event = request_to_detection_event(request.form, confidence) + + config.detection_events[sound_name].append(d_event) + + config.class_detection_events[sound_class].append(d_event) return jsonify( error=0, @@ -85,7 +88,7 @@ def post_notify(): @config.app.route("/locations/", methods=["GET"]) -def get_sound_positions(sound_name): +def get_sound_locations(sound_name): """ Gets the sound position given the sound name @@ -111,6 +114,37 @@ def get_sound_positions(sound_name): return json.dumps(ret_list) +@config.app.route("/class/locations/", methods=["GET"]) +def get_class_locations(class_name): + """ + + HIGHLY EXPERIMENTAL: DO NOT FUCK WITH! + + """ + + if not class_name in config.class_detection_events.keys(): + return json.dumps([]) + + # I can do the average like this because we are assuming we are trying + # to track the same class of things, so the reference data must be + # similar + radius, spl = db.get_class_reference_data(class_name) + + location_list = tri.determine_sound_locations( + radius, spl, + config.class_detection_events[class_name], + disp=0 + ) + + ret_list = list() + + for location in location_list: + ret_list.append(location.to_dict()) + + return json.dumps(ret_list) + + + @config.app.route("/viewer/", methods=["GET"]) def get_position_viewer(sound_name): """ diff --git a/tests/test_server.py b/tests/test_server.py index fc8c3af..e09fcfe 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -17,6 +17,7 @@ server_addr = socket.gethostbyname(socket.getfqdn()) server_port = 8000 +test_sound_class = "Chicken" test_sound_name = "Cock" loc = api.Locaudio(server_addr, server_port) @@ -72,10 +73,16 @@ def test_server_notify_not_added(self): print "\n=== Server Notify Not Added === :: {0}\n".format(ret_dict) - def test_server_triangulation(self): + def test_server_sound_triangulation(self): pos_list = loc.get_sound_locations(test_sound_name) - print "\n=== Server Triangulation === :: {0}\n".format(pos_list) + print "\n=== Server Sound Triangulation === :: {0}\n".format(pos_list) + + + def test_server_class_triangulation(self): + pos_list = loc.get_class_locations(test_sound_class) + + print "\n=== Server Class Triangulation === :: {0}\n".format(pos_list) def test_names(self):