diff --git a/datadreamer/dataset_annotation/sam2_annotator.py b/datadreamer/dataset_annotation/sam2_annotator.py index 2c50805..ceb55e6 100644 --- a/datadreamer/dataset_annotation/sam2_annotator.py +++ b/datadreamer/dataset_annotation/sam2_annotator.py @@ -101,11 +101,12 @@ def annotate_batch( image_masks = [] for j in range(len(boxes)): - mask, score = masks_batch[i][j], scores_batch[i][j] + mask, score = masks_batch[i][j].astype(np.uint8), scores_batch[i][j] if score < iou_threshold: image_masks.append([]) continue - mask = mask.astype(np.uint8) + if len(mask.shape) == 3: + mask = mask.squeeze(0) polygon = mask_to_polygon(mask) image_masks.append(polygon if len(polygon) != 0 else [])