From 3645cf82b64faa3b673d2f05f48e9612386cd37c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Thu, 26 Sep 2024 16:50:27 +0200 Subject: [PATCH] asso func created initalized on first frame --- boxmot/trackers/basetracker.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/boxmot/trackers/basetracker.py b/boxmot/trackers/basetracker.py index 5a73ae5e7..4987ae7ba 100644 --- a/boxmot/trackers/basetracker.py +++ b/boxmot/trackers/basetracker.py @@ -4,6 +4,7 @@ import colorsys from abc import ABC, abstractmethod from boxmot.utils import logger as LOGGER +from boxmot.utils.iou import AssociationFunction class BaseTracker(ABC): @@ -15,7 +16,8 @@ def __init__( iou_threshold: float = 0.3, max_obs: int = 50, nr_classes: int = 80, - per_class: bool = False + per_class: bool = False, + asso_func: str = 'iou' ): """ Initialize the BaseTracker object with detection threshold, maximum age, minimum hits, @@ -39,10 +41,12 @@ def __init__( self.nr_classes = nr_classes self.iou_threshold = iou_threshold self.last_emb_size = None + self.asso_func_name = asso_func self.frame_count = 0 self.active_tracks = [] # This might be handled differently in derived classes self.per_class_active_tracks = None + self._first_frame_processed = False # Flag to track if the first frame has been processed # Initialize per-class active tracks if self.per_class: @@ -92,6 +96,29 @@ def get_class_dets_n_embs(self, dets, embs, cls_id): class_embs = None return class_dets, class_embs + @staticmethod + def on_first_frame_setup(method): + """ + Decorator to perform setup on the first frame only. + This ensures that initialization tasks (like setting the association function) only + happen once, on the first frame, and are skipped on subsequent frames. + """ + def wrapper(self, *args, **kwargs): + # If setup hasn't been done yet, perform it + if not self._first_frame_processed: + img = args[1] + self.h, self.w = img.shape[0:2] + self.asso_func = AssociationFunction(w=self.w, h=self.h, asso_mode=self.asso_func_name).asso_func + + # Mark that the first frame setup has been done + self._first_frame_processed = True + + # Call the original method (e.g., update) + return method(self, *args, **kwargs) + + return wrapper + + @staticmethod def per_class_decorator(update_method): """