Merge pull request #39 from kadirnar/add_sahi
Add return to load_model function
kadirnar authored Apr 11, 2023
2 parents 15e7549 + 59b273a commit 64396b8
Showing 5 changed files with 83 additions and 41 deletions.
49 changes: 36 additions & 13 deletions README.md
@@ -26,48 +26,71 @@ from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor

# For image

autoseg_image = SegAutoMaskPredictor().image_predict(
results = SegAutoMaskPredictor().image_predict(
source="image.jpg",
model_type="vit_l", # vit_l, vit_h, vit_b
points_per_side=16,
points_per_batch=64,
min_area=0,
output_path="output.jpg",
show=True,
save=False,
)

# For video

autoseg_video = SegAutoMaskPredictor().video_predict(
results = SegAutoMaskPredictor().video_predict(
source="video.mp4",
model_type="vit_l", # vit_l, vit_h, vit_b
points_per_side=16,
points_per_batch=64,
min_area=1000,
output_path="output.mp4",
)

# For manual box and point selection

seg_manual_mask_generator = SegManualMaskPredictor().image_predict(
results = SegManualMaskPredictor().image_predict(
source="image.jpg",
model_type="vit_l", # vit_l, vit_h, vit_b
input_point=[[100, 100], [200, 200]],
input_label=[0, 1],
input_box=[100, 100, 200, 200], # x,y,w,h
input_box=[100, 100, 200, 200], # or [[100, 100, 200, 200], [100, 100, 200, 200]]
multimask_output=False,

random_color=False,
show=True,
save=False,
)
```

# For multi box selection
# SAHI + Segment Anything

seg_manual_mask_generator = SegManualMaskPredictor().image_predict(
source="data/brain.png",
model_type="vit_l",
input_point=None,
input_label=None,
input_box= [[100, 100, 400, 400]],
multimask_output=False,
```python
from metaseg.sahi_predict import SahiAutoSegmentation, sahi_sliced_predict

image_path = "test.jpg"
boxes = sahi_sliced_predict(
image_path=image_path,
detection_model_type="yolov5", #yolov8, detectron2, mmdetection, torchvision
detection_model_path="yolov5l6.pt",
conf_th=0.25,
image_size=1280,
slice_height=256,
slice_width=256,
overlap_height_ratio=0.2,
overlap_width_ratio=0.2,
)

SahiAutoSegmentation().predict(
source=image_path,
model_type="vit_b",
input_box=boxes,
multimask_output=False,
random_color=False,
show=True,
save=False,
)
```
<img width="1000" alt="teaser" src="https://github.com/kadirnar/segment-anything-pip/releases/download/v0.5.0/sahi_autoseg.png">

# Extra Features

- [x] Support for Yolov5/8, Detectron2, Mmdetection, Torchvision models
2 changes: 1 addition & 1 deletion metaseg/__init__.py
@@ -9,4 +9,4 @@
from metaseg.generator.predictor import SamPredictor
from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor

__version__ = "0.4.5"
__version__ = "0.5.0"
43 changes: 21 additions & 22 deletions metaseg/mask_predictor.py
@@ -6,15 +6,13 @@
from tqdm import tqdm

from metaseg import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from metaseg.utils import download_model, load_box, load_image, load_mask, load_video, multi_boxes
from metaseg.utils import download_model, load_box, load_image, load_mask, load_video, multi_boxes, show_image, save_image


class SegAutoMaskPredictor:
def __init__(self):
self.model = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.save = False
self.show = False

def load_model(self, model_type):
if self.model is None:
@@ -24,7 +22,7 @@ def load_model(self, model_type):

return self.model

def image_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.png"):
def image_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.png", show=False, save=False):
read_image = load_image(source)
model = self.load_model(model_type)
mask_generator = SamAutomaticMaskGenerator(
@@ -49,15 +47,15 @@ def image_predict(self, source, model_type, points_per_side, points_per_batch, m
mask_image = cv2.add(mask_image, img)

combined_mask = cv2.add(read_image, mask_image)
if self.save:
cv2.imwrite(output_path, combined_mask)

if self.show:
cv2.imshow("Output", combined_mask)
cv2.waitKey(0)
cv2.destroyAllWindows()

return output_path
self.combined_mask = combined_mask
if show:
show_image(combined_mask)
if save:
save_image(output_image=combined_mask, output_path=output_path)

return masks


def video_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.mp4"):
cap, out = load_video(source, output_path)
@@ -128,6 +126,9 @@ def image_predict(
input_label=None,
multimask_output=False,
output_path="output.png",
random_color=False,
show=False,
save=False,
):
image = load_image(source)
model = self.load_model(model_type)
@@ -144,7 +145,7 @@ def image_predict(
multimask_output=False,
)
for mask in masks:
mask_image = load_mask(mask.cpu().numpy(), False)
mask_image = load_mask(mask.cpu().numpy(), random_color)

for box in input_boxes:
image = load_box(box.cpu().numpy(), image)
@@ -158,16 +159,14 @@
box=input_boxes,
multimask_output=multimask_output,
)
mask_image = load_mask(masks, True)
mask_image = load_mask(masks, random_color)
image = load_box(input_box, image)

combined_mask = cv2.add(image, mask_image)
if self.save:
cv2.imwrite(output_path, combined_mask)
if save:
save_image(output_image=combined_mask, output_path=output_path)

if self.show:
cv2.imshow("Output", combined_mask)
cv2.waitKey(0)
cv2.destroyAllWindows()
if show:
show_image(combined_mask)

return output_path
return masks
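
The updated `image_predict` above now returns the generated masks instead of an output path, and `show`/`save` are plain keyword arguments rather than instance state. A minimal usage sketch, not part of the diff, assuming a local `image.jpg` and the `vit_l` checkpoint; the mask dict keys follow `SamAutomaticMaskGenerator.generate()`:

```python
from metaseg import SegAutoMaskPredictor

# show/save are the new keyword arguments; the call returns the list of masks
# produced by SamAutomaticMaskGenerator rather than an output path.
masks = SegAutoMaskPredictor().image_predict(
    source="image.jpg",   # assumed local test image
    model_type="vit_l",   # vit_l, vit_h, vit_b
    points_per_side=16,
    points_per_batch=64,
    min_area=0,
    output_path="output.jpg",
    show=False,
    save=True,
)

# Each entry is a dict from SamAutomaticMaskGenerator.generate(): the boolean
# mask is under "segmentation" and its pixel count under "area".
for m in masks:
    print(m["area"], m["segmentation"].shape)
```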
19 changes: 14 additions & 5 deletions metaseg/sahi_predict.py
@@ -6,7 +6,7 @@
from metaseg.utils import download_model, load_image, multi_boxes, plt_load_box, plt_load_mask


def sahi_predict(
def sahi_sliced_predict(
image_path,
detection_model_type,
detection_model_path,
@@ -51,7 +51,7 @@ def sahi_predict(
return boxes


class SahiPredictor:
class SahiAutoSegmentation:
def __init__(self):
self.model = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -62,15 +62,21 @@ def load_model(self, model_type):
self.model = sam_model_registry[model_type](checkpoint=self.model_path)
self.model.to(device=self.device)

def save_image(
return self.model

def predict(
self,
source,
model_type,
input_box=None,
input_point=None,
input_label=None,
multimask_output=False,
random_color=False,
save=False,
show=True,
):

read_image = load_image(source)
model = self.load_model(model_type)
predictor = SamPredictor(model)
@@ -99,8 +105,11 @@ def save_image(
plt.figure(figsize=(10, 10))
plt.imshow(read_image)
for mask in masks:
plt_load_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
plt_load_mask(mask.cpu().numpy(), plt.gca(), random_color=random_color)
for box in input_boxes:
plt_load_box(box.cpu().numpy(), plt.gca())
plt.axis("off")
plt.show()
if save:
plt.savefig("output.png")
if show:
plt.show()
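
The commit title ("Add return to load_model function") refers to the `return self.model` added above, so the loaded SAM model can now be reused outside of `predict`. A small sketch under that assumption; `SamPredictor` is re-exported by `metaseg` (see the `__init__.py` diff), and the `vit_b` checkpoint is assumed to be fetched by the imported `download_model` helper on first use:

```python
from metaseg import SamPredictor
from metaseg.sahi_predict import SahiAutoSegmentation

# load_model now returns the SAM model it loads, so it can be wrapped in a
# SamPredictor directly instead of only being used internally by predict().
segmenter = SahiAutoSegmentation()
sam = segmenter.load_model("vit_b")  # downloads the checkpoint if not cached
predictor = SamPredictor(sam)
```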
11 changes: 11 additions & 0 deletions metaseg/utils/data_utils.py
@@ -75,3 +75,14 @@ def multi_boxes(boxes, predictor, image):
input_boxes = torch.tensor(boxes, device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
return input_boxes, transformed_boxes

def show_image(output_image):
    import cv2

    cv2.imshow("output", output_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def save_image(output_image, output_path):
    import cv2

    cv2.imwrite(output_path, output_image)
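
A quick way to exercise the two new helpers on their own; `image.jpg` is a placeholder path, and the import path assumes `metaseg.utils` re-exports them as the `mask_predictor.py` import above suggests:

```python
import cv2

from metaseg.utils import save_image, show_image

img = cv2.imread("image.jpg")  # placeholder input image
save_image(output_image=img, output_path="copy.jpg")
show_image(img)  # opens an OpenCV window; blocks until a key is pressed
```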
