diff --git a/segment_anything/modeling/mask_decoder.py b/segment_anything/modeling/mask_decoder.py index 5d2fdb03d..9f799a5ff 100644 --- a/segment_anything/modeling/mask_decoder.py +++ b/segment_anything/modeling/mask_decoder.py @@ -123,7 +123,11 @@ def predict_masks( tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # Expand per-image data in batch direction to be per-mask - src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + if image_embeddings.shape[0] != tokens.shape[0]: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + src = image_embeddings + src = src + dense_prompt_embeddings pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) b, c, h, w = src.shape