Skip to content

Commit

Permalink
Merge pull request #1677 from mikel-brostrom/association-first-frame-…
Browse files Browse the repository at this point in the history
…setup

asso func created initalized on first frame
  • Loading branch information
mikel-brostrom authored Sep 26, 2024
2 parents 02b94d8 + 38da1b6 commit 9acd77e
Show file tree
Hide file tree
Showing 12 changed files with 331 additions and 280 deletions.
4 changes: 2 additions & 2 deletions boxmot/configs/botsort.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ appearance_thresh:

cmc_method:
type: choice
default: sof # from the default parameters
options: [sof, cmc]
default: ecc # from the default parameters
options: [sof, ecc]
2 changes: 1 addition & 1 deletion boxmot/configs/ocsort.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ delta_t:
asso_func:
type: choice
default: iou # from the default parameters
options: ['iou', 'giou', 'centroid']
options: ['iou', 'giou']

use_byte:
type: choice
Expand Down
29 changes: 28 additions & 1 deletion boxmot/trackers/basetracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import colorsys
from abc import ABC, abstractmethod
from boxmot.utils import logger as LOGGER
from boxmot.utils.iou import AssociationFunction


class BaseTracker(ABC):
Expand All @@ -15,7 +16,8 @@ def __init__(
iou_threshold: float = 0.3,
max_obs: int = 50,
nr_classes: int = 80,
per_class: bool = False
per_class: bool = False,
asso_func: str = 'iou'
):
"""
Initialize the BaseTracker object with detection threshold, maximum age, minimum hits,
Expand All @@ -39,10 +41,12 @@ def __init__(
self.nr_classes = nr_classes
self.iou_threshold = iou_threshold
self.last_emb_size = None
self.asso_func_name = asso_func

self.frame_count = 0
self.active_tracks = [] # This might be handled differently in derived classes
self.per_class_active_tracks = None
self._first_frame_processed = False # Flag to track if the first frame has been processed

# Initialize per-class active tracks
if self.per_class:
Expand Down Expand Up @@ -92,6 +96,29 @@ def get_class_dets_n_embs(self, dets, embs, cls_id):
class_embs = None
return class_dets, class_embs

@staticmethod
def on_first_frame_setup(method):
"""
Decorator to perform setup on the first frame only.
This ensures that initialization tasks (like setting the association function) only
happen once, on the first frame, and are skipped on subsequent frames.
"""
def wrapper(self, *args, **kwargs):
# If setup hasn't been done yet, perform it
if not self._first_frame_processed:
img = args[1]
self.h, self.w = img.shape[0:2]
self.asso_func = AssociationFunction(w=self.w, h=self.h, asso_mode=self.asso_func_name).asso_func

# Mark that the first frame setup has been done
self._first_frame_processed = True

# Call the original method (e.g., update)
return method(self, *args, **kwargs)

return wrapper


@staticmethod
def per_class_decorator(update_method):
"""
Expand Down
5 changes: 3 additions & 2 deletions boxmot/trackers/botsort/botsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
match_thresh: float = 0.8,
proximity_thresh: float = 0.5,
appearance_thresh: float = 0.25,
cmc_method: str = "sof",
cmc_method: str = "ecc",
frame_rate=30,
fuse_first_associate: bool = False,
with_reid: bool = True,
Expand Down Expand Up @@ -81,9 +81,10 @@ def __init__(
weights=reid_weights, device=device, half=half
).model

self.cmc = get_cmc_method('ecc')()
self.cmc = get_cmc_method(cmc_method)()
self.fuse_first_associate = fuse_first_associate

@BaseTracker.on_first_frame_setup
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
self.check_inputs(dets, img)
Expand Down
1 change: 1 addition & 0 deletions boxmot/trackers/bytetrack/byte_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
self.max_time_lost = self.buffer_size
self.kalman_filter = KalmanFilterXYAH()

@BaseTracker.on_first_frame_setup
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray = None, embs: np.ndarray = None) -> np.ndarray:

Expand Down
6 changes: 3 additions & 3 deletions boxmot/trackers/deepocsort/deep_ocsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from boxmot.motion.kalman_filters.xysr_kf import KalmanFilterXYSR
from boxmot.motion.kalman_filters.xywh_kf import KalmanFilterXYWH
from boxmot.utils.association import associate, linear_assignment
from boxmot.utils.iou import get_asso_func
from boxmot.trackers.basetracker import BaseTracker
from boxmot.utils.ops import xyxy2xysr

Expand Down Expand Up @@ -272,7 +271,7 @@ def __init__(
Q_s_scaling: float = 0.0001,
**kwargs: dict
):
super().__init__(max_age=max_age, per_class=per_class)
super().__init__(max_age=max_age, per_class=per_class, asso_func=asso_func)
"""
Sets key parameters for SORT
"""
Expand All @@ -281,7 +280,7 @@ def __init__(
self.iou_threshold = iou_threshold
self.det_thresh = det_thresh
self.delta_t = delta_t
self.asso_func = get_asso_func(asso_func)
self.asso_func = asso_func
self.inertia = inertia
self.w_association_emb = w_association_emb
self.alpha_fixed_emb = alpha_fixed_emb
Expand All @@ -300,6 +299,7 @@ def __init__(
self.cmc_off = cmc_off
self.aw_off = aw_off

@BaseTracker.on_first_frame_setup
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
"""
Expand Down
5 changes: 2 additions & 3 deletions boxmot/trackers/hybridsort/hybridsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from boxmot.trackers.hybridsort.association import (
associate_4_points_with_score, associate_4_points_with_score_with_reid,
cal_score_dif_batch_two_score, embedding_distance, linear_assignment)
from boxmot.utils.iou import get_asso_func
from boxmot.trackers.basetracker import BaseTracker


Expand Down Expand Up @@ -353,7 +352,7 @@ class HybridSORT(BaseTracker):
"""
def __init__(self, reid_weights, device, half, det_thresh, per_class=False, max_age=30, min_hits=3,
iou_threshold=0.3, delta_t=3, asso_func="iou", inertia=0.2, longterm_reid_weight=0, TCM_first_step_weight=0, use_byte=False):
super().__init__(max_age=max_age, per_class=per_class)
super().__init__(max_age=max_age, per_class=per_class, asso_func=asso_func)

"""
Sets key parameters for SORT
Expand All @@ -365,7 +364,6 @@ def __init__(self, reid_weights, device, half, det_thresh, per_class=False, max_
self.frame_count: int = 0
self.det_thresh: float = det_thresh
self.delta_t: int = delta_t
self.asso_func: str = get_asso_func(asso_func) # assuming get_asso_func returns a callable function
self.inertia: float = inertia
self.use_byte: bool = use_byte
self.low_thresh: float = 0.1
Expand Down Expand Up @@ -394,6 +392,7 @@ def camera_update(self, trackers, warp_matrix):
for tracker in trackers:
tracker.camera_update(warp_matrix)

@BaseTracker.on_first_frame_setup
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
"""
Expand Down
1 change: 1 addition & 0 deletions boxmot/trackers/imprassoc/impr_assoc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def __init__(
self.cmc = SOF()


@BaseTracker.on_first_frame_setup
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
self.check_inputs(dets, img)
Expand Down
8 changes: 3 additions & 5 deletions boxmot/trackers/ocsort/ocsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

from boxmot.motion.kalman_filters.xysr_kf import KalmanFilterXYSR
from boxmot.utils.association import associate, linear_assignment
from boxmot.utils.iou import get_asso_func
from boxmot.utils.iou import run_asso_func
from boxmot.trackers.basetracker import BaseTracker
from boxmot.utils.ops import xyxy2xysr

Expand Down Expand Up @@ -212,7 +210,7 @@ def __init__(
Q_xy_scaling: float = 0.01,
Q_s_scaling: float = 0.0001
):
super().__init__(max_age=max_age, per_class=per_class)
super().__init__(max_age=max_age, per_class=per_class, asso_func=asso_func)
"""
Sets key parameters for SORT
"""
Expand All @@ -223,13 +221,13 @@ def __init__(
self.frame_count = 0
self.det_thresh = det_thresh
self.delta_t = delta_t
self.asso_func = get_asso_func(asso_func)
self.inertia = inertia
self.use_byte = use_byte
self.Q_xy_scaling = Q_xy_scaling
self.Q_s_scaling = Q_s_scaling
KalmanBoxTracker.count = 0

@BaseTracker.on_first_frame_setup
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
"""
Expand Down Expand Up @@ -327,7 +325,7 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) ->
if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
left_dets = dets[unmatched_dets]
left_trks = last_boxes[unmatched_trks]
iou_left = run_asso_func(self.asso_func, left_dets, left_trks, w, h)
iou_left = self.asso_func(left_dets, left_trks)
iou_left = np.array(iou_left)
if iou_left.max() > self.asso_threshold:
"""
Expand Down
8 changes: 4 additions & 4 deletions boxmot/utils/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from boxmot.utils.iou import iou_batch, centroid_batch, run_asso_func
from boxmot.utils.iou import AssociationFunction


def speed_direction_batch(dets, tracks):
Expand Down Expand Up @@ -40,7 +40,7 @@ def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
np.empty((0, 5), dtype=int),
)

iou_matrix = iou_batch(detections, trackers)
iou_matrix = AssociationFunction.iou_batch(detections, trackers)

if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
Expand Down Expand Up @@ -143,7 +143,7 @@ def associate(
valid_mask = np.ones(previous_obs.shape[0])
valid_mask[np.where(previous_obs[:, 4] < 0)] = 0

iou_matrix = run_asso_func(asso_func, detections, trackers, w, h)
iou_matrix = asso_func(detections, trackers)
#iou_matrix = iou_batch(detections, trackers)
scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
# iou_matrix = iou_matrix * scores # a trick sometiems works, we don't encourage this
Expand Down Expand Up @@ -235,7 +235,7 @@ def associate_kitti(
"""
Cost from IoU
"""
iou_matrix = iou_batch(detections, trackers)
iou_matrix = AssociationFunction.iou_batch(detections, trackers)

"""
With multiple categories, generate the cost for catgory mismatch
Expand Down
Loading

0 comments on commit 9acd77e

Please sign in to comment.