Skip to content

Commit

Permalink
add mot utils (#122)
Browse files Browse the repository at this point in the history
* add mot dataset creation utils

* add docs for mot utils

* add tests for mot utils

* update workflows for mot utils

* add norfair to test dependencies

* update docs

* clean code
  • Loading branch information
fcakyon authored Jun 10, 2021
1 parent 82bf1ba commit 5b63a63
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 7 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ jobs:
if: matrix.operating-system == 'macos-latest'
run: pip install torch==1.8.1 torchvision==0.9.1

- name: Install MMDetection(2.12.0) and YOLOv5(5.0.5)
- name: Install MMDetection(2.12.0), YOLOv5(5.0.5) and Norfair(0.3.0)
run: >
pip install mmcv-full==1.3.4 mmdet==2.12.0 yolov5==5.0.5
pip install mmcv-full==1.3.4 mmdet==2.12.0 yolov5==5.0.5 norfair==0.3.0
- name: Lint with flake8
run: |
Expand All @@ -88,7 +88,7 @@ jobs:
run: >
pip install -e .
- name: Test SAHI scripts for MMDetection(2.12.0) and YOLOv5(5.0.5)
- name: Test SAHI scripts for MMDetection(2.12.0), YOLOv5(5.0.5) and Norfair(0.3.0)
run: |
# predict mmdet
python scripts/predict.py --source tests/data/ --model_path tests/data/models/mmdet_retinanet/retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth --config_path tests/data/models/mmdet_retinanet/retinanet_r50_fpn_1x_coco.py
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/package_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,16 @@ jobs:
if: matrix.operating-system == 'macos-latest'
run: pip install torch==1.8.1 torchvision==0.9.1

- name: Install MMDetection(2.12.0) and YOLOv5(5.0.5)
- name: Install MMDetection(2.12.0), YOLOv5(5.0.5) and Norfair(0.3.0)
run: >
pip install mmcv-full==1.3.4 mmdet==2.12.0 yolov5==5.0.5
pip install mmcv-full==1.3.4 mmdet==2.12.0 yolov5==5.0.5 norfair==0.3.0
- name: Test with unittest
run: |
pip install pytest
python -m unittest
- name: Test SAHI scripts for MMDetection(2.12.0) and YOLOv5(5.0.5)
- name: Test SAHI scripts for MMDetection(2.12.0), YOLOv5(5.0.5) and Norfair(0.3.0)
run: |
# predict mmdet
python scripts/predict.py --source tests/data/ --model_path tests/data/models/mmdet_retinanet/retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth --config_path tests/data/models/mmdet_retinanet/retinanet_r50_fpn_1x_coco.py
Expand Down
58 changes: 58 additions & 0 deletions docs/MOT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# MOT Utilities

## MOT dataset creation steps:

- import required classes:

```python
from sahi.utils.mot import MotAnnotation, MotFrame, MotVideo
```

- init video:

```python
mot_video = MotVideo(export_dir="mot_video")
```

- init first frame:

```python
mot_frame = MotFrame()
```

- add annotations to frame:

```python
mot_frame.add_annotation(
MotAnnotation(bbox=[x_min, y_min, width, height])
)

mot_frame.add_annotation(
MotAnnotation(bbox=[x_min, y_min, width, height])
)
```

- add frame to video:

```python
mot_video.add_frame(mot_frame)
```

- after adding all frames, your MOT-formatted files are ready in the `mot_video/` folder.

## Advanced MOT dataset creation:

- you can customize the tracker while initializing the `MotVideo` object:

```python
tracker_params = {
'max_distance_between_points': 30,
'min_detection_threshold': 0,
'hit_inertia_min': 10,
'hit_inertia_max': 12,
'point_transience': 4,
}
# for details: https://github.com/tryolabs/norfair/tree/master/docs#arguments

mot_video = MotVideo(export_dir="mot_video", tracker_kwargs=tracker_params)
```
202 changes: 202 additions & 0 deletions sahi/utils/mot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import os
from pathlib import Path
from typing import Optional, List, Dict

import numpy as np
from sahi.utils.file import increment_path


try:
import norfair
from norfair import Tracker, Detection
from norfair.metrics import PredictionsTextFile, InformationFile
except ImportError:
raise ImportError('Please run "pip install -U norfair" to install norfair first for MOT format handling.')


class GroundTruthTextFile(PredictionsTextFile):
    """Append tracked objects to ``<save_path>/gt/gt.txt``, one text row per object."""

    def __init__(self, save_path="."):
        """
        Args:
            save_path (str): Directory in which the "gt" folder (holding gt.txt) is created.
        """
        # NOTE(review): the parent initializer is not invoked here; presumably
        # this is deliberate because only the output path and the frame counter
        # are needed - confirm against norfair's PredictionsTextFile.
        predictions_folder = os.path.join(save_path, "gt")
        # exist_ok avoids the check-then-create race of a separate exists() test
        os.makedirs(predictions_folder, exist_ok=True)

        self.out_file_name = os.path.join(predictions_folder, "gt" + ".txt")

        # frame number used by the next update() call when none is given
        self.frame_number = 1

    def update(self, predictions, frame_number=None):
        """
        Write tracked object information in the output file (for this frame), in the format
        frame_number, id, bb_left, bb_top, bb_width, bb_height, 1, -1, -1, -1

        Args:
            predictions: Iterable of tracked objects. Each object must expose ``id`` and an
                ``estimate`` array indexable as [0, 0], [0, 1], [1, 0], [1, 1]
                (read here as x_min, y_min, x_max, y_max).
            frame_number: Frame number written in every row. Defaults to the internal
                counter ``self.frame_number``.

        The internal counter is incremented by one on every call, whether or not
        ``frame_number`` was given explicitly.
        """
        if frame_number is None:
            frame_number = self.frame_number

        # Build all rows first so that a malformed object cannot leave a
        # partially written frame in the output file.
        rows = []
        for obj in predictions:
            fields = (
                str(int(frame_number)),
                str(int(obj.id)),
                str(obj.estimate[0, 0]),  # bb_left
                str(obj.estimate[0, 1]),  # bb_top
                str(obj.estimate[1, 0] - obj.estimate[0, 0]),  # bb_width
                str(obj.estimate[1, 1] - obj.estimate[0, 1]),  # bb_height
            )
            rows.append(",".join(fields) + ",1,-1,-1,-1" + "\n")

        # The context manager guarantees the handle is closed even if writing fails.
        with open(self.out_file_name, "a+") as text_file:
            text_file.writelines(rows)

        self.frame_number += 1


def euclidean_distance(detection, tracked_object):
    """Return the norm of the difference between detection points and the tracked estimate.

    Args:
        detection: Object exposing a ``points`` array.
        tracked_object: Object exposing an ``estimate`` array.
    """
    offset = detection.points - tracked_object.estimate
    return np.linalg.norm(offset)


class MotAnnotation:
    """A single bounding-box annotation together with its score."""

    def __init__(self, bbox: List[int], score: Optional[float] = 1):
        """
        Args:
            bbox (List[int]): [x_min, y_min, width, height]
            score (Optional[float]): Score attached to the box. Defaults to 1.
        """
        self.bbox, self.score = bbox, score


class MotFrame:
    """Collects the annotations that belong to a single video frame."""

    def __init__(self):
        # annotations in the order they were added
        self.annotation_list: List[MotAnnotation] = []

    def add_annotation(self, detection: MotAnnotation):
        """
        Append an annotation to this frame.

        Args:
            detection (MotAnnotation): Annotation to add.
        """
        # isinstance also accepts subclasses of MotAnnotation
        assert isinstance(detection, MotAnnotation), "'detection' should be a MotAnnotation object."
        self.annotation_list.append(detection)

    def to_norfair_detections(self, track_points: str = "bbox"):
        """
        Convert every annotation of this frame into a norfair ``Detection``.

        Args:
            track_points (str): 'centroid' or 'bbox'. Defaults to 'bbox'.

        Returns:
            List[Detection]: One detection per annotation, in insertion order.

        Raises:
            ValueError: If ``track_points`` is neither 'centroid' nor 'bbox'.
        """
        # Validate up front: the original built the ValueError without raising
        # it, so an invalid mode only surfaced later as an unbound 'points'.
        if track_points not in ("bbox", "centroid"):
            raise ValueError("'track_points' should be one of ['centroid', 'bbox'].")

        norfair_detections: List[Detection] = []
        # convert all detections to norfair detections
        for annotation in self.annotation_list:
            # bbox is [x_min, y_min, width, height]; derive the corner points
            xmin = annotation.bbox[0]
            ymin = annotation.bbox[1]
            xmax = annotation.bbox[0] + annotation.bbox[2]
            ymax = annotation.bbox[1] + annotation.bbox[3]
            scores = None
            if track_points == "bbox":
                # two points: top-left and bottom-right, one score per point
                points = np.array([[xmin, ymin], [xmax, ymax]])
                if annotation.score is not None:
                    scores = np.array([annotation.score, annotation.score])
            else:
                # single point: centre of the box
                points = np.array([(xmin + xmax) / 2, (ymin + ymax) / 2])
                if annotation.score is not None:
                    scores = np.array([annotation.score])
            # create norfair formatted detection
            norfair_detections.append(Detection(points=points, scores=scores))
        return norfair_detections


class MotVideo:
    """Builds MOT-format ground truth (gt/gt.txt and seqinfo.ini) one frame at a time."""

    def __init__(
        self, export_dir: str = "runs/mot", track_points: str = "bbox", tracker_kwargs: Optional[Dict] = None
    ):
        """
        Args:
            export_dir (str): Folder directory that will contain gt/gt.txt and seqinfo.ini
                For details: https://github.com/tryolabs/norfair/issues/42#issuecomment-819211873
            track_points (str): Track detections based on 'centroid' or 'bbox'. Defaults to 'bbox'.
            tracker_kwargs (dict): Optional dict of tracker parameters. Recognised keys:
                - max_distance_between_points (int)
                - min_detection_threshold (float)
                - hit_inertia_min (int)
                - hit_inertia_max (int)
                - point_transience (int)
                Missing keys fall back to defaults; None is treated as an empty dict.
                For details: https://github.com/tryolabs/norfair/tree/master/docs#arguments
        """
        # None instead of a shared dict() default; also makes an explicit
        # tracker_kwargs=None (allowed by the Optional annotation) work.
        if tracker_kwargs is None:
            tracker_kwargs = {}

        # NOTE(review): increment_path presumably picks an unused directory name
        # when export_dir already exists - confirm in sahi.utils.file.
        self.export_dir: str = str(increment_path(Path(export_dir), exist_ok=False))
        self.track_points: str = track_points

        self.groundtruth_text_file: Optional[GroundTruthTextFile] = None
        self.tracker: Optional[Tracker] = None

        self._create_gt_file()
        self._init_tracker(
            tracker_kwargs.get("max_distance_between_points", 30),
            tracker_kwargs.get("min_detection_threshold", 0),
            tracker_kwargs.get("hit_inertia_min", 10),
            tracker_kwargs.get("hit_inertia_max", 12),
            tracker_kwargs.get("point_transience", 4),
        )

    def _create_info_file(self, seq_length: int):
        """
        Args:
            seq_length (int): Number of frames present in video (seqLength parameter in seqinfo.ini)
                For details: https://github.com/tryolabs/norfair/issues/42#issuecomment-819211873
        """
        # set file path
        filepath = Path(self.export_dir) / "seqinfo.ini"
        # create folder directory (including missing parents) if not exists
        filepath.parent.mkdir(parents=True, exist_ok=True)
        # (re)write seqinfo.ini with the current seqLength
        with open(str(filepath), "w") as file:
            file.write(f"seqLength={seq_length}")

    def _create_gt_file(self):
        """Create the ground-truth text file writer rooted at the export directory."""
        self.groundtruth_text_file = GroundTruthTextFile(save_path=self.export_dir)

    def _init_tracker(
        self,
        max_distance_between_points: int = 30,
        min_detection_threshold: float = 0,
        hit_inertia_min: int = 10,
        hit_inertia_max: int = 12,
        point_transience: int = 4,
    ):
        """
        Args:
            max_distance_between_points (int): passed to the tracker as distance_threshold
            min_detection_threshold (float): passed to the tracker as detection_threshold
            hit_inertia_min (int)
            hit_inertia_max (int)
            point_transience (int)
            For details: https://github.com/tryolabs/norfair/tree/master/docs#arguments
        """
        self.tracker = Tracker(
            distance_function=euclidean_distance,
            initialization_delay=0,
            distance_threshold=max_distance_between_points,
            detection_threshold=min_detection_threshold,
            hit_inertia_min=hit_inertia_min,
            hit_inertia_max=hit_inertia_max,
            point_transience=point_transience,
        )

    def add_frame(self, frame: MotFrame):
        """
        Track the annotations of ``frame``, append them to gt.txt and refresh seqinfo.ini.

        Args:
            frame (MotFrame): Frame holding the annotations of the next video frame.
        """
        assert isinstance(frame, MotFrame), "'frame' should be a MotFrame object."
        norfair_detections: List[Detection] = frame.to_norfair_detections(track_points=self.track_points)
        tracked_objects = self.tracker.update(detections=norfair_detections)
        self.groundtruth_text_file.update(predictions=tracked_objects)
        # The writer's counter starts at 1 and is advanced by update(), so it now
        # points at the NEXT frame; the number of frames written is one less.
        self._create_info_file(seq_length=self.groundtruth_text_file.frame_number - 1)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_version():
python_requires=">=3.7",
install_requires=get_requirements(),
extras_require={
"tests": ["pytest", "mmdet"],
"tests": ["pytest", "mmdet", "norfair"],
"dev": ["black==21.5b1", "flake==3.9.2", "isort==5.8.0", "jupyterlab==3.0.14"],
},
classifiers=[
Expand Down
29 changes: 29 additions & 0 deletions tests/test_motutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2020.

import os
import shutil
import unittest


class TestMotUtils(unittest.TestCase):
    """Smoke test for the MOT dataset creation utilities."""

    def test_mot_vid(self):
        """Feed two frames (one box, then two boxes) through MotVideo."""
        from sahi.utils.mot import MotAnnotation, MotFrame, MotVideo

        mot_video = MotVideo(export_dir="tests/data/mot/")

        # bounding boxes per frame, each as [x_min, y_min, width, height]
        frames_bboxes = (
            ([10, 10, 100, 100],),  # frame 0
            ([12, 12, 98, 98], [95, 95, 98, 98]),  # frame 1
        )
        for bboxes in frames_bboxes:
            mot_frame = MotFrame()
            for bbox in bboxes:
                mot_frame.add_annotation(MotAnnotation(bbox=bbox))
            mot_video.add_frame(mot_frame)


if __name__ == "__main__":
    unittest.main()

0 comments on commit 5b63a63

Please sign in to comment.