Cleaning and adapting YOLO detections for any image size
adityarauniyar committed Dec 24, 2024
1 parent 48949d3 commit dee4982
Showing 1 changed file with 69 additions and 93 deletions.
@@ -61,8 +61,8 @@

# build path to the model checkpoint
SAM2_MODEL_CHECKPOINT_PATH = str(pathlib.Path(__file__).parent / SEGMENT_ANYTHING_2_REPO_PATH / "checkpoints" / SAM2_MODEL_CHECKPOINT)
-logger.debug(f'Model checkpoint: {SAM2_MODEL_CHECKPOINT}')
-logger.debug(f'Model config: {SAM2_MODEL_CONFIG}')
+logger.info(f'Model checkpoint: {SAM2_MODEL_CHECKPOINT}')
+logger.info(f'Model config: {SAM2_MODEL_CONFIG}')
SAM2_PREDICTOR = build_sam2_video_predictor(SAM2_MODEL_CONFIG, SAM2_MODEL_CHECKPOINT_PATH)


@@ -589,7 +589,14 @@ def add_tracking(self, tracking: ObjectTracking):
"""
self.state.object_trackings[tracking.tracking_id] = tracking

-def get_YOLO_detections(self, conf, iou, yolo_model_checkpoint, image_size, video_source_path, video_fps, output_frames_dir="predictions/yolo"):
+def get_YOLO_detections(self,
+    conf,
+    iou,
+    yolo_model_checkpoint,
+    video_source_path,
+    video_fps,
+    max_frames_to_track,
+    output_frames_dir="predictions/yolo"):
"""
Run YOLO detection on the entire video, overlay the detections on each frame, and save each frame as an image.
"""
@@ -606,6 +613,7 @@ def get_YOLO_detections(self, conf, iou, yolo_model_checkpoint, image_size, vide
# Get the frame width and height from the video
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
logger.info(f"Video dimensions: width={width}, height={height}")

# Ensure output directory exists
if not os.path.exists(output_frames_dir):
@@ -614,7 +622,7 @@ def get_YOLO_detections(self, conf, iou, yolo_model_checkpoint, image_size, vide
logger.info(f"Reading frames from {video_source_path} at {video_fps} FPS...")

frame_id = 0
-frames_to_track = 1
+frames_to_track = min(frame_count, max_frames_to_track)

while cap.isOpened() and frame_id < frames_to_track:
# Read from a specific timestamp based on the fps and frame_id
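As an aside, the seek-by-timestamp read that the comment above describes can be sketched as follows with OpenCV; the path and values are illustrative:

    import cv2

    cap = cv2.VideoCapture("input.mp4")    # hypothetical video path
    video_fps, frame_id = 30, 10           # illustrative values
    # Position the capture at the frame's timestamp (in milliseconds) ...
    cap.set(cv2.CAP_PROP_POS_MSEC, (frame_id / video_fps) * 1000.0)
    # ... then read the frame at that position.
    ret, frame = cap.read()
    cap.release()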
@@ -625,12 +633,14 @@ def get_YOLO_detections(self, conf, iou, yolo_model_checkpoint, image_size, vide
if not ret:
break # End of video

-# Run YOLO detection on the current frame
+# Run YOLO detection on the current frame.
+# Image size must be a multiple of the max stride (32).
+img_size = min(width, height) - (min(width, height) % 32)
detection_result = self.yolo_model.track(
frame,
conf=conf,
iou=iou,
-imgsz=image_size,
+imgsz=img_size,
persist=True,
show=False,
verbose=False,
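The new img_size expression rounds the shorter video side down to the nearest multiple of 32, since YOLO's convolutional backbone downsamples by a maximum stride of 32 (Ultralytics warns and adjusts when imgsz is not stride-aligned). A worked example with illustrative dimensions:

    # Round the shorter side down to a multiple of the max stride (32).
    width, height = 1920, 1080             # illustrative video dimensions
    stride = 32
    img_size = min(width, height) - (min(width, height) % stride)
    assert img_size == 1056 and img_size % stride == 0

This is what makes the detections work "for any image size", per the commit message: the inference resolution now adapts to the source video instead of being passed in.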
@@ -687,116 +697,82 @@ def get_YOLO_detections(self, conf, iou, yolo_model_checkpoint, image_size, vide
cap.release()
logger.info(f"Frames with detections saved to {output_frames_dir}")

-def get_sam_tracking_with_yolo_prompts(self, video_fps: int, frames_to_track: int, sam_batchsize: int = 100, video_source_path: str = None):
+def get_sam_tracking_with_yolo_prompts(
+    self, video_fps: int, frames_to_track: int, sam_batchsize: int = 100, video_source_path: str = None
+):
"""
Processes a video using SAM tracking with YOLO prompts.
Args:
video_fps (int): Frames per second for the video.
frames_to_track (int): Total number of frames to process.
sam_batchsize (int): Batch size for SAM processing.
video_source_path (str): Path to the source video.
"""
current_frame = 0
-# -1 is done to avoid out of range error as propagation only happens N -1 frames
-sam_end_frame = min(self.get_video_fps_duration(video_source_path, fps=video_fps)[0], frames_to_track) - 1
-
-sam_range_start = -1
-sam_range_end = -1
+frame_count, duration = self.get_video_fps_duration(video_source_path, fps=video_fps)
+
+sam_end_frame = min(frame_count, frames_to_track)
+logger.info(f"Processing video with SAM tracking for {sam_end_frame} frames...")

# Split the video into frames
with tempfile.TemporaryDirectory() as temp_img_dir:
-# # use persisted dir for debug
-temp_img_dir = '/tmp/frames'
-os.makedirs(temp_img_dir, exist_ok=True)
+# temp_img_dir = '/tmp/frames' # Use persisted directory for debugging
+# os.makedirs(temp_img_dir, exist_ok=True)

-# get all frames
-frames = list(self.split_frames(
-    video_path=video_source_path, temp_dir=temp_img_dir, video_fps=video_fps,
-    start_frame=0,
-    end_frame=sam_end_frame + 1
-))
+frames = list(
+    self.split_frames(
+        video_path=video_source_path, temp_dir=temp_img_dir, video_fps=video_fps, start_frame=0, end_frame=sam_end_frame + 1
+    )
+)

-# Set the video source to the temporary image directory
-os.path.isdir(video_source_path)
-self.video_source = ImageFolderSource(
-    temp_img_dir, sorting_rule=lambda x: x
-)
-self.image_shape = self.get_frame(0).shape[:2]
-
-height, width, _ = frames[0][1].shape
-logger.debug(f'Video width={width}, height={height}')
+self.video_source = ImageFolderSource(temp_img_dir, sorting_rule=lambda x: x)
+self.image_shape = self.get_frame(0).shape[:2]
+logger.info(f"Video dimensions: width={self.image_shape[1]}, height={self.image_shape[0]}")

with torch.autocast("cuda", torch.bfloat16):
-while current_frame < sam_end_frame - 1:
-    logger.debug(f"Processing frames using SAM Tracker...")
+while current_frame < sam_end_frame:
+    logger.debug(f"Processing frame {current_frame}")

-    logger.info(f"Processing frame {current_frame}")

# prepare SAM prompts at current frame
sam_prompts = self.prepare_sam_prompts(current_frame)

-if len(sam_prompts) == 0:
-    logger.debug(f"No SAM prompts for frame {current_frame}")
+if not sam_prompts:
+    logger.debug(f"No SAM prompts for frame {current_frame}. Skipping.")
current_frame += 1
-    continue # if no prompts is possible, skip the frame
-
-# initialize predictor state if we switched to a new batch
-logger.info(f"Processing frame {current_frame} till {sam_range_end}")
-# if current_frame >= sam_range_end:
-#     sam_range_start = current_frame
-#     sam_range_end = current_frame + sam_batchsize
-#
-#     # limit the end frame to the total frame count
-#     sam_range_end = min(sam_range_end, frame_count - 1)
-#
-#     image_paths = self.video_source.get_image_paths(
-#         sam_range_start, sam_range_end + 1
-#     ) # predict 1 more frame because we want 1 overlap to propagate the masks
-#     logger.debug(f"Image paths: {image_paths}")
-#     predictor_state = self.sam2_predictor.init_state(frame_paths=image_paths)
+    continue

logger.debug(f"Initializing inference state for frame {current_frame}")
inference_state = self.get_inference_state(video_dir=temp_img_dir)
self.sam2_predictor.reset_state(inference_state)

logger.debug(f"Adding {len(sam_prompts)} prompts to SAM predictor")
for tracking_id, mask_bbox in sam_prompts:
# add the prompt to the predictor
_, out_obj_ids, out_mask_logits = self.sam2_predictor.add_new_points_or_box(
inference_state=inference_state,
-frame_idx=current_frame - sam_range_start,
+frame_idx=current_frame,
obj_id=tracking_id,
box=mask_bbox,
)
logger.debug(f"Added {len(sam_prompts)} prompts to SAM")
# propagate the masks along the video. upon receiving the masks,
# test can current detections be explained by the masks. If there
# are unexplained detections, stop the propagation and add them to
# the SAM prompts.
-for (
-    out_frame_idx,
-    out_obj_ids,
-    out_mask_logits,
-) in self.sam2_predictor.propagate_in_video(inference_state):
+for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_predictor.propagate_in_video(inference_state):
for i, out_obj_id in enumerate(out_obj_ids):
mask = (out_mask_logits[i] > 0.0).cpu().numpy()[0]
-mask_arr = np.asfortranarray(mask)
-
-if mask_arr.any():
+if mask.any():
    mask_obj = Mask(
-        encoded=coco_mask.encode(np.asfortranarray(mask_arr)),
-        shape=mask_arr.shape,
+        encoded=coco_mask.encode(np.asfortranarray(mask)),
+        shape=mask.shape,
    )
+    self.append_mask_to_tracking(out_obj_id, out_frame_idx, mask_obj)

-# add the mask to the tracking
-self.append_mask_to_tracking(
-    out_obj_id, out_frame_idx + sam_range_start, mask_obj
-)
-
-# explain the detections with the masks
-unexplained_detections = self.get_unexplained_detections_at_frame(
-    out_frame_idx + sam_range_start
-)
-logger.debug(f"Received {len(unexplained_detections)} unexplained detections")
-
-if len(unexplained_detections) > 0:
-    break # stop the propagation
+# unexplained_detections = self.get_unexplained_detections_at_frame(out_frame_idx)
+# logger.debug(f"Unexplained detections at frame {out_frame_idx}: {len(unexplained_detections)}")
+#
+# if unexplained_detections:
+#     logger.warning(f"Stopping propagation at frame {out_frame_idx} due to unexplained detections.")
+#     break

# update the current frame and repeat the process
-logger.debug(f"Setting current frame to {out_frame_idx + sam_range_start}")
-current_frame = out_frame_idx + sam_range_start
+current_frame = out_frame_idx
+logger.info(f"Updated current frame to {current_frame}")

def prepare_sam_prompts(self, frame_id: int):
"""
@@ -993,14 +969,14 @@ def get_regions_from_yolo_sam2_tracker(self,
self.get_YOLO_detections(conf=conf,
iou=iou,
yolo_model_checkpoint=yolo_model_checkpoint,
-image_size=image_size,
video_source_path=video_source_path,
-video_fps=video_fps)
+video_fps=video_fps,
+max_frames_to_track=max_frames_to_track)

######### SAM TRACKING #########
-self.build_sam2_predictor(sam2_model_checkpoint_path, sam2_model_cfg, sam2_max_frames_to_track, prompt_type, annotation_workaround)
+self.build_sam2_predictor(sam2_model_checkpoint_path, sam2_model_cfg, max_frames_to_track, prompt_type, annotation_workaround)
self.get_sam_tracking_with_yolo_prompts(video_fps=video_fps,
-frames_to_track=sam2_max_frames_to_track,
+frames_to_track=max_frames_to_track,
video_source_path=video_source_path)

# return the regions
@@ -1020,7 +996,7 @@ def is_control_matched(cls, control: ControlTag) -> bool:
# check object tag type
if control.objects[0].tag != "Video":
return False
-if get_bool(control.attr, "model_sam_tracker", "false"):
+if not get_bool(control.attr, "model_sam_tracker", "false"):
return False
return True
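The inverted condition is the substantive fix in this hunk: previously the control matched only when model_sam_tracker was missing or false, which is backwards. A minimal self-contained sketch of the intended predicate, assuming get_bool parses a string-valued attribute with a default (the helper body here is assumed, not taken from this repo):

    def get_bool(attrs: dict, key: str, default: str = "false") -> bool:
        # Assumed semantics: look up a string attribute and parse it as a boolean.
        return str(attrs.get(key, default)).strip().lower() in ("true", "1", "yes")

    # The control should match only when the tracker flag is explicitly enabled:
    assert get_bool({"model_sam_tracker": "true"}, "model_sam_tracker") is True
    assert get_bool({}, "model_sam_tracker") is False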

@@ -1055,7 +1031,7 @@ def predict_regions(self, path) -> List[Dict]:
image_size=image_size,
sam2_model_cfg=SAM2_MODEL_CONFIG,
sam2_model_checkpoint_path=SAM2_MODEL_CHECKPOINT_PATH,
-sam2_max_frames_to_track=frames_to_track,
+max_frames_to_track=frames_to_track,
prompt_type=PROMPT_TYPE,
annotation_workaround=ANNOTATION_WORKAROUND,
video_source_path=path,