Skip to content

Commit

Permalink
fix: correct dims in case of automatic mask generation
Browse files Browse the repository at this point in the history
  • Loading branch information
SauravMaheshkar committed Aug 4, 2024
1 parent 9905e92 commit 995edbc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["flit_core >=3.2,<4"]
build-backend = "flit_core.buildapi"

[project]
description = "fork of the official SAMv2 implementation with cpu support"
description = "CPU compatible fork of the official SAMv2 implementation"
name = "samv2"
version = "0.0.4"
authors = [{ name = "Saurav Maheshkar", email = "[email protected]" }]
Expand Down
2 changes: 1 addition & 1 deletion sam2/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def show_masks(
output_image = Image.new(
mode="RGBA",
size=(
masks[0]["segmentation"].shape[0],
masks[0]["segmentation"].shape[1],
masks[0]["segmentation"].shape[0],
),
color=(0, 0, 0),
)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_image_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,14 @@ def test_prompt_mode(load_image, image_predictor, point, label, box, num_masks)
@pytest.mark.full
def test_mask_generation(load_image, mask_generator) -> None:
masks = mask_generator.generate(load_image)

output_mask = show_masks(
image=load_image,
masks=masks, # type: ignore
scores=None,
only_best=False,
autogenerated_mask=True,
)

assert len(masks) > 0
assert isinstance(output_mask, Image.Image)

0 comments on commit 995edbc

Please sign in to comment.