From bcf8620fd1fb3d8d29a2333ddcadb88e86f7ef47 Mon Sep 17 00:00:00 2001 From: jasons1425 Date: Sun, 3 Jul 2022 16:03:38 +0800 Subject: [PATCH 1/2] Count every STrack within each individual ByteTracker --- yolox/tracker/basetrack.py | 8 ++++++-- yolox/tracker/byte_tracker.py | 10 ++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/yolox/tracker/basetrack.py b/yolox/tracker/basetrack.py index d5837b05..a5a2a643 100644 --- a/yolox/tracker/basetrack.py +++ b/yolox/tracker/basetrack.py @@ -10,8 +10,9 @@ class TrackState(object): class BaseTrack(object): - def __init__(self): + def __init__(self, count_gen=None): self._count = 0 + self.count_gen = count_gen self.track_id = 0 self.is_activated = False @@ -33,7 +34,10 @@ def end_frame(self): return self.frame_id def next_id(self): - self._count += 1 + if self.count_gen: + self._count = self.count_gen.__next__() + else: + self._count += 1 return self._count def activate(self, *args): diff --git a/yolox/tracker/byte_tracker.py b/yolox/tracker/byte_tracker.py index 2d004599..02b6cdeb 100644 --- a/yolox/tracker/byte_tracker.py +++ b/yolox/tracker/byte_tracker.py @@ -5,6 +5,7 @@ import copy import torch import torch.nn.functional as F +import itertools from .kalman_filter import KalmanFilter from yolox.tracker import matching @@ -12,8 +13,8 @@ class STrack(BaseTrack): shared_kalman = KalmanFilter() - def __init__(self, tlwh, score): - + def __init__(self, tlwh, score, count_gen=None): + super().__init__(count_gen) # wait activate self._tlwh = np.asarray(tlwh, dtype=np.float) self.kalman_filter = None @@ -155,6 +156,7 @@ def __init__(self, args, frame_rate=30): self.buffer_size = int(frame_rate / 30.0 * args.track_buffer) self.max_time_lost = self.buffer_size self.kalman_filter = KalmanFilter() + self.track_count_gen = itertools.count() def update(self, output_results, img_info, img_size): self.frame_id += 1 @@ -186,7 +188,7 @@ def update(self, output_results, img_info, img_size): if len(dets) > 0: '''Detections''' - detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for + detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s, self.track_count_gen) for (tlbr, s) in zip(dets, scores_keep)] else: detections = [] @@ -223,7 +225,7 @@ def update(self, output_results, img_info, img_size): # association the untrack to the low score detections if len(dets_second) > 0: '''Detections''' - detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for + detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s, self.track_count_gen) for (tlbr, s) in zip(dets_second, scores_second)] else: detections_second = [] From b057f438aa5c9ea74960cafba8052033407acc9d Mon Sep 17 00:00:00 2001 From: jasons1425 Date: Sun, 3 Jul 2022 16:21:31 +0800 Subject: [PATCH 2/2] set track_id starts at 1 --- yolox/tracker/byte_tracker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yolox/tracker/byte_tracker.py b/yolox/tracker/byte_tracker.py index 02b6cdeb..b9609e9a 100644 --- a/yolox/tracker/byte_tracker.py +++ b/yolox/tracker/byte_tracker.py @@ -156,7 +156,7 @@ def __init__(self, args, frame_rate=30): self.buffer_size = int(frame_rate / 30.0 * args.track_buffer) self.max_time_lost = self.buffer_size self.kalman_filter = KalmanFilter() - self.track_count_gen = itertools.count() + self.track_count_gen = itertools.count(start=1) def update(self, output_results, img_info, img_size): self.frame_id += 1