Skip to content

Commit

Permalink
added functions to the base computer vision detector class
Browse files Browse the repository at this point in the history
  • Loading branch information
sahilshah379 committed Sep 29, 2023
1 parent 6264436 commit f7db178
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 33 deletions.
16 changes: 16 additions & 0 deletions ns_vfs/model/vision/_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import supervision as sv
import numpy as np
import abc


Expand Down Expand Up @@ -27,6 +29,20 @@ def get_weight(self):
"""Get weight."""
return self._weight

def get_labels(self) -> list:
    """Return the human-readable label strings (``"<class> <confidence>"``)
    produced by the most recent call to ``detect``.

    Note: the previous docstring ("Return sv.Detections") was a copy-paste
    error from ``get_detections``; this method returns label strings.
    """
    return self._labels

def get_detections(self) -> sv.Detections:
    """Return the ``sv.Detections`` object stored by the most recent call
    to ``detect``.

    Bug fix: this previously returned ``self._detection`` (singular), but
    the concrete detectors (GroundingDino, Yolo) store the result in
    ``self._detections`` (plural), so the getter raised AttributeError.
    """
    return self._detections

def get_confidence(self) -> np.ndarray:
    """Return the per-detection confidence scores recorded by the most
    recent call to ``detect``."""
    confidence_scores = self._confidence
    return confidence_scores

def get_size(self) -> int:
    """Return the number of objects found by the most recent call to
    ``detect``."""
    detection_count = self._size
    return detection_count

@abc.abstractmethod
def detect(self, frame) -> any:
"""Detect object in frame."""
Expand Down
20 changes: 15 additions & 5 deletions ns_vfs/model/vision/grounding_dino.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

import warnings

from groundingdino.util.inference import Model
from omegaconf import DictConfig
import numpy as np
import warnings

from ns_vfs.model.vision._base import ComputerVisionDetector

warnings.filterwarnings("ignore")
import numpy as np


class GroundingDino(ComputerVisionDetector):
Expand Down Expand Up @@ -59,11 +58,22 @@ def detect(self, frame_img: np.ndarray, classes: list) -> any:
Returns:
any: Detections.
"""
detections = self.model.predict_with_classes(
detected_obj = self.model.predict_with_classes(
image=frame_img,
classes=self._parse_class_name(class_names=classes),
box_threshold=self._config.BOX_TRESHOLD,
text_threshold=self._config.TEXT_TRESHOLD,
)

return detections
self._labels = [
f"{classes[class_id] if class_id is not None else None} {confidence:0.2f}"
for _, _, confidence, class_id, _ in detected_obj
]

self._detections = detected_obj

self._confidence = detected_obj.confidence

self._size = len(detected_obj)

return detected_obj
24 changes: 18 additions & 6 deletions ns_vfs/model/vision/yolo.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

import warnings

from ultralytics import YOLO
from omegaconf import DictConfig
from ultralytics import YOLO
import supervision as sv
import numpy as np
import warnings

from ns_vfs.model.vision._base import ComputerVisionDetector

warnings.filterwarnings("ignore")
import numpy as np


class Yolo(ComputerVisionDetector):
Expand Down Expand Up @@ -58,9 +58,21 @@ def detect(self, frame_img: np.ndarray, classes: list) -> any:
"""
classes_reversed = {v:k for k, v in self.model.names.items()}
class_ids = [classes_reversed[c] for c in classes]
detections = self.model.predict(
detected_obj = self.model.predict(
source=frame_img,
classes=class_ids
)

return detections
self._labels = []
for i in range(len(detected_obj[0].boxes)):
class_id = int(detected_obj[0].boxes.cls[i])
confidence = float(detected_obj[0].boxes.conf[i])
self._labels.append(f"{detected_obj[0].names[class_id] if class_id is not None else None} {confidence:0.2f}")

self._detections = sv.Detections(xyxy=detected_obj[0].boxes.xyxy.cpu().detach().numpy())

self._confidence = detected_obj[0].boxes.conf.cpu().detach().numpy()

self._size = len(detected_obj[0].boxes)

return detected_obj
20 changes: 4 additions & 16 deletions ns_vfs/video_to_automaton.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,29 +64,19 @@ def _sigmoid(self, x, k=1, x0=0) -> float:
def _annotate_frame(
self,
frame_img: np.ndarray,
proposition: list,
detected_obj: any,
output_dir: str | None = None,
) -> None:
"""Annotate frame with bounding box.
Args:
frame_img (np.ndarray): Frame image.
proposition (list): List of propositions.
detected_obj (any): Detected object.
output_dir (str | None, optional): Output directory. Defaults to None.
"""
box_annotator = sv.BoxAnnotator()
labels = []
for i in range(len(detected_obj[0].boxes)):
class_id = int(detected_obj[0].boxes.cls[i])
confidence = float(detected_obj[0].boxes.conf[i])
labels.append(f"{detected_obj[0].names[class_id] if class_id is not None else None} {confidence:0.2f}")

detections = sv.Detections(xyxy=detected_obj[0].boxes.xyxy.cpu().detach().numpy())

annotated_frame = box_annotator.annotate(
scene=frame_img.copy(), detections=detections, labels=labels
scene=frame_img.copy(), detections=self._detector.get_detections(), labels=self._detector.get_labels()
)

sv.plot_image(annotated_frame, (16, 16))
Expand Down Expand Up @@ -161,15 +151,13 @@ def get_probabilistic_proposition_from_frame(
Returns:
float: Probabilistic proposition from frame.
"""
detected_obj = self._detector.detect(frame_img, [proposition])
if len(detected_obj[0].boxes) > 0:
self._detector.detect(frame_img, [proposition])
if self._detector.get_size() > 0:
if is_annotation:
self._annotate_frame(
frame_img=frame_img,
detected_obj=detected_obj,
proposition=[proposition],
)
return self._mapping_probability(np.round(np.max(detected_obj[0].boxes.conf.cpu().detach().numpy()), 2))
return self._mapping_probability(np.round(np.max(self._detector.get_confidence()), 2))
# probability of the object in the frame
else:
return 0 # probability of the object in the frame is 0
Expand Down
16 changes: 10 additions & 6 deletions run_frame_to_automata.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

from ns_vfs.config.loader import load_config
from ns_vfs.model.vision.grounding_dino import GroundingDino
from ns_vfs.model.vision.yolo import Yolo
from ns_vfs.processor.video_processor import (
VideoFrameWindowProcessor,
)
from ns_vfs.processor.video_processor import VideoFrameWindowProcessor
from ns_vfs.video_to_automaton import VideotoAutomaton

if __name__ == "__main__":
Expand All @@ -15,9 +14,14 @@
config = load_config()

frame2automaton = VideotoAutomaton(
detector=Yolo(
config=config.YOLO,
weight_path=config.YOLO.YOLO_CHECKPOINT_PATH,
# detector=Yolo(
# config=config.YOLO,
# weight_path=config.YOLO.YOLO_CHECKPOINT_PATH,
# ),
detector=GroundingDino(
config=config.GROUNDING_DINO,
weight_path=config.GROUNDING_DINO.GROUNDING_DINO_CHECKPOINT_PATH,
config_path=config.GROUNDING_DINO.GROUNDING_DINO_CONFIG_PATH,
),
video_processor=VideoFrameWindowProcessor(
video_path=sample_video_path,
Expand Down

0 comments on commit f7db178

Please sign in to comment.