Skip to content

Commit

Permalink
Merge pull request #1341 from mikel-brostrom/plot-traj
Browse files Browse the repository at this point in the history
Plot trajectories
  • Loading branch information
mikel-brostrom authored Mar 7, 2024
2 parents cfcd196 + db048b8 commit 2e79489
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 72 deletions.
160 changes: 158 additions & 2 deletions boxmot/trackers/basetracker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
import numpy as np
import cv2 as cv
import hashlib
import colorsys


class BaseTracker(object):
def __init__(self, det_thresh=0.3, max_age=30, min_hits=3, iou_threshold=0.3):
def __init__(self, det_thresh: float = 0.3, max_age: int = 30, min_hits: int = 3, iou_threshold: float = 0.3):
"""
Initialize the BaseTracker object with detection threshold, maximum age, minimum hits,
and Intersection Over Union (IOU) threshold for tracking objects in video frames.
Parameters:
- det_thresh (float): Detection threshold for considering detections.
- max_age (int): Maximum age of a track before it is considered lost.
- min_hits (int): Minimum number of detection hits before a track is considered confirmed.
- iou_threshold (float): IOU threshold for determining match between detection and tracks.
Attributes:
- frame_count (int): Counter for the frames processed.
- active_tracks (list): List to hold active tracks, may be used differently in subclasses.
"""
self.det_thresh = det_thresh
self.max_age = max_age
self.min_hits = min_hits
Expand All @@ -8,5 +28,141 @@ def __init__(self, det_thresh=0.3, max_age=30, min_hits=3, iou_threshold=0.3):
self.frame_count = 0
self.active_tracks = [] # This might be handled differently in derived classes

def update(self, dets, img, embs=None):
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> None:
"""
Abstract method to update the tracker with new detections for a new frame. This method
should be implemented by subclasses.
Parameters:
- dets (np.ndarray): Array of detections for the current frame.
- img (np.ndarray): The current frame as an image array.
- embs (np.ndarray, optional): Embeddings associated with the detections, if any.
Raises:
- NotImplementedError: If the subclass does not implement this method.
"""
raise NotImplementedError("The update method needs to be implemented by the subclass.")

def id_to_color(self, id: int, saturation: float = 0.75, value: float = 0.95) -> tuple:
"""
Generates a consistent unique BGR color for a given ID using hashing.
Parameters:
- id (int): Unique identifier for which to generate a color.
- saturation (float): Saturation value for the color in HSV space.
- value (float): Value (brightness) for the color in HSV space.
Returns:
- tuple: A tuple representing the BGR color.
"""

# Hash the ID to get a consistent unique value
hash_object = hashlib.sha256(str(id).encode())
hash_digest = hash_object.hexdigest()

# Convert the first few characters of the hash to an integer
# and map it to a value between 0 and 1 for the hue
hue = int(hash_digest[:8], 16) / 0xffffffff

# Convert HSV to RGB
rgb = colorsys.hsv_to_rgb(hue, saturation, value)

# Convert RGB from 0-1 range to 0-255 range and format as hexadecimal
rgb_255 = tuple(int(component * 255) for component in rgb)
hex_color = '#%02x%02x%02x' % rgb_255
# Strip the '#' character and convert the string to RGB integers
rgb = tuple(int(hex_color.strip('#')[i:i+2], 16) for i in (0, 2, 4))

# Convert RGB to BGR for OpenCV
bgr = rgb[::-1]

return bgr

def plot_box_on_img(self, img: np.ndarray, box: tuple, conf: float, cls: int, id: int) -> np.ndarray:
"""
Draws a bounding box with ID, confidence, and class information on an image.
Parameters:
- img (np.ndarray): The image array to draw on.
- box (tuple): The bounding box coordinates as (x1, y1, x2, y2).
- conf (float): Confidence score of the detection.
- cls (int): Class ID of the detection.
- id (int): Unique identifier for the detection.
Returns:
- np.ndarray: The image array with the bounding box drawn on it.
"""

thickness = 2
fontscale = 0.5

img = cv.rectangle(
img,
(int(box[0]), int(box[1])),
(int(box[2]), int(box[3])),
self.id_to_color(id),
thickness
)
img = cv.putText(
img,
f'id: {int(id)}, conf: {conf:.2f}, c: {int(cls)}',
(int(box[0]), int(box[1]) - 10),
cv.FONT_HERSHEY_SIMPLEX,
fontscale,
self.id_to_color(id),
thickness
)
return img


def plot_trackers_trajectories(self, img: np.ndarray, observations: list, id: int) -> np.ndarray:
"""
Draws the trajectories of tracked objects based on historical observations. Each point
in the trajectory is represented by a circle, with the thickness increasing for more
recent observations to visualize the path of movement.
Parameters:
- img (np.ndarray): The image array on which to draw the trajectories.
- observations (list): A list of bounding box coordinates representing the historical
observations of a tracked object. Each observation is in the format (x1, y1, x2, y2).
- id (int): The unique identifier of the tracked object for color consistency in visualization.
Returns:
- np.ndarray: The image array with the trajectories drawn on it.
"""
for i, box in enumerate(observations):
trajectory_thickness = int(np.sqrt(float (i + 1)) * 1.2)
img = cv.circle(
img,
(int((box[0] + box[2]) / 2),
int((box[1] + box[3]) / 2)),
2,
color=self.id_to_color(int(id)),
thickness=trajectory_thickness
)
return img


def plot_results(self, img: np.ndarray, show_trajectories: bool) -> np.ndarray:
"""
Visualizes the trajectories of all active tracks on the image. For each track,
it draws the latest bounding box and the path of movement if the history of
observations is longer than two. This helps in understanding the movement patterns
of each tracked object.
Parameters:
- img (np.ndarray): The image array on which to draw the trajectories and bounding boxes.
Returns:
- np.ndarray: The image array with trajectories and bounding boxes of all active tracks.
"""
for a in self.active_tracks:
if a.history_observations:
if len(a.history_observations) > 2:
box = a.history_observations[-1]
img = self.plot_box_on_img(img, box, a.conf, a.cls, a.id)
if show_trajectories:
img = self.plot_trackers_trajectories(img, a.history_observations, a.id)

return img

44 changes: 23 additions & 21 deletions boxmot/trackers/botsort/bot_sort.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

from collections import deque

import numpy as np
from collections import deque

from boxmot.appearance.reid_multibackend import ReIDDetectMultiBackend
from boxmot.motion.cmc.sof import SOF
Expand All @@ -21,14 +20,15 @@ class STrack(BaseTrack):
def __init__(self, det, feat=None, feat_history=50):
# wait activate
self.xywh = xyxy2xywh(det[0:4]) # (x1, y1, x2, y2) --> (xc, yc, w, h)
self.score = det[4]
self.conf = det[4]
self.cls = det[5]
self.det_ind = det[6]
self.kalman_filter = None
self.mean, self.covariance = None, None
self.is_activated = False
self.cls_hist = [] # (cls id, freq)
self.update_cls(self.cls, self.score)
self.update_cls(self.cls, self.conf)
self.history_observations = deque([], maxlen=50)

self.tracklet_len = 0

Expand All @@ -49,23 +49,23 @@ def update_features(self, feat):
self.features.append(feat)
self.smooth_feat /= np.linalg.norm(self.smooth_feat)

def update_cls(self, cls, score):
def update_cls(self, cls, conf):
if len(self.cls_hist) > 0:
max_freq = 0
found = False
for c in self.cls_hist:
if cls == c[0]:
c[1] += score
c[1] += conf
found = True

if c[1] > max_freq:
max_freq = c[1]
self.cls = c[0]
if not found:
self.cls_hist.append([cls, score])
self.cls_hist.append([cls, conf])
self.cls = cls
else:
self.cls_hist.append([cls, score])
self.cls_hist.append([cls, conf])
self.cls = cls

def predict(self):
Expand Down Expand Up @@ -138,11 +138,11 @@ def re_activate(self, new_track, frame_id, new_id=False):
self.frame_id = frame_id
if new_id:
self.id = self.next_id()
self.score = new_track.score
self.conf = new_track.conf
self.cls = new_track.cls
self.det_ind = new_track.det_ind

self.update_cls(new_track.cls, new_track.score)
self.update_cls(new_track.cls, new_track.conf)

def update(self, new_track, frame_id):
"""
Expand All @@ -155,6 +155,8 @@ def update(self, new_track, frame_id):
self.frame_id = frame_id
self.tracklet_len += 1

self.history_observations.append(self.xyxy)

self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, new_track.xywh
)
Expand All @@ -165,10 +167,10 @@ def update(self, new_track, frame_id):
self.state = TrackState.Tracked
self.is_activated = True

self.score = new_track.score
self.conf = new_track.conf
self.cls = new_track.cls
self.det_ind = new_track.det_ind
self.update_cls(new_track.cls, new_track.score)
self.update_cls(new_track.cls, new_track.conf)

@property
def xyxy(self):
Expand Down Expand Up @@ -281,17 +283,17 @@ def update(self, dets, img, embs=None):
else:
detections = []

""" Add newly detected tracklets to tracked_stracks"""
""" Add newly detected tracklets to active_tracks"""
unconfirmed = []
tracked_stracks = [] # type: list[STrack]
active_tracks = [] # type: list[STrack]
for track in self.active_tracks:
if not track.is_activated:
unconfirmed.append(track)
else:
tracked_stracks.append(track)
active_tracks.append(track)

""" Step 2: First association, with high score detection boxes"""
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
""" Step 2: First association, with high conf detection boxes"""
strack_pool = joint_stracks(active_tracks, self.lost_stracks)

# Predict the current location with KF
STrack.multi_predict(strack_pool)
Expand All @@ -301,7 +303,7 @@ def update(self, dets, img, embs=None):
STrack.multi_gmc(strack_pool, warp)
STrack.multi_gmc(unconfirmed, warp)

# Associate with high score detection boxes
# Associate with high conf detection boxes
ious_dists = iou_distance(strack_pool, detections)
ious_dists_mask = ious_dists > self.proximity_thresh
if self.fuse_first_associate:
Expand Down Expand Up @@ -329,7 +331,7 @@ def update(self, dets, img, embs=None):
track.re_activate(det, self.frame_count, new_id=False)
refind_stracks.append(track)

""" Step 3: Second association, with low score detection boxes"""
""" Step 3: Second association, with low conf detection boxes"""
if len(dets_second) > 0:
"""Detections"""
detections_second = [STrack(dets_second) for dets_second in dets_second]
Expand Down Expand Up @@ -386,7 +388,7 @@ def update(self, dets, img, embs=None):
""" Step 4: Init new stracks"""
for inew in u_detection:
track = detections[inew]
if track.score < self.new_track_thresh:
if track.conf < self.new_track_thresh:
continue

track.activate(self.kalman_filter, self.frame_count)
Expand Down Expand Up @@ -418,7 +420,7 @@ def update(self, dets, img, embs=None):
output = []
output.extend(t.xyxy)
output.append(t.id)
output.append(t.score)
output.append(t.conf)
output.append(t.cls)
output.append(t.det_ind)
outputs.append(output)
Expand Down
2 changes: 1 addition & 1 deletion boxmot/trackers/bytetrack/basetrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class BaseTrack(object):
history = OrderedDict()
features = []
curr_feature = None
score = 0
conf = 0
start_frame = 0
frame_id = 0
time_since_update = 0
Expand Down
Loading

0 comments on commit 2e79489

Please sign in to comment.