diff --git a/deep_sort/linear_assignment.py b/deep_sort/linear_assignment.py index 178456cf..c095bbc7 100644 --- a/deep_sort/linear_assignment.py +++ b/deep_sort/linear_assignment.py @@ -1,7 +1,7 @@ # vim: expandtab:ts=4:sw=4 from __future__ import absolute_import import numpy as np -from sklearn.utils.linear_assignment_ import linear_assignment +from scipy.optimize import linear_sum_assignment as linear_assignment from . import kalman_filter @@ -56,6 +56,7 @@ def min_cost_matching( tracks, detections, track_indices, detection_indices) cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5 indices = linear_assignment(cost_matrix) + indices = np.hstack([indices[0].reshape(((indices[0].shape[0]), 1)),indices[1].reshape(((indices[0].shape[0]), 1))]) matches, unmatched_tracks, unmatched_detections = [], [], [] for col, detection_idx in enumerate(detection_indices): diff --git a/tools/generate_detections.py b/tools/generate_detections.py index c7192c26..6ac56a2d 100644 --- a/tools/generate_detections.py +++ b/tools/generate_detections.py @@ -72,15 +72,15 @@ class ImageEncoder(object): def __init__(self, checkpoint_filename, input_name="images", output_name="features"): - self.session = tf.Session() - with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle: - graph_def = tf.GraphDef() + self.session = tf.compat.v1.Session() + with tf.compat.v1.gfile.GFile(checkpoint_filename, "rb") as file_handle: + graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(file_handle.read()) - tf.import_graph_def(graph_def, name="net") - self.input_var = tf.get_default_graph().get_tensor_by_name( - "net/%s:0" % input_name) - self.output_var = tf.get_default_graph().get_tensor_by_name( - "net/%s:0" % output_name) + tf.compat.v1.import_graph_def(graph_def, name="net") + self.input_var = tf.compat.v1.get_default_graph().get_tensor_by_name( + "%s:0" % input_name) #"net/%s:0" % input_name + self.output_var = tf.compat.v1.get_default_graph().get_tensor_by_name( + "%s:0" % output_name) #"net/%s:0" % output_name assert len(self.output_var.get_shape()) == 2 assert len(self.input_var.get_shape()) == 4