diff --git a/samgeo/text_sam.py b/samgeo/text_sam.py index 2617ac00..1c010111 100644 --- a/samgeo/text_sam.py +++ b/samgeo/text_sam.py @@ -337,12 +337,18 @@ def predict( masks = torch.tensor([]) if len(boxes) > 0: masks = self.predict_sam(image_pil, boxes) - if 1 in masks.shape: + # If masks have 4 dimensions and the second dimension is 1 (e.g., [boxes, 1, height, width]), + # squeeze that dimension to reduce it to 3 dimensions ([boxes, height, width]). + # If boxes = 1, the mask's shape will be [1, height, width] after squeezing. + if masks.ndim == 4 and masks.shape[1] == 1: masks = masks.squeeze(1) if boxes.nelement() == 0: # No "object" instances found print("No objects found in the image.") - return + mask_overlay = np.zeros_like( + image_np[..., 0], dtype=dtype + ) # Create an empty mask overlay + else: # Create an empty image to store the mask overlays mask_overlay = np.zeros_like(