fix: LEAP-932: Having pulled suggestions from ML backend causes validation on Labeling Interface (#496)

* fix: LEAP-932: Having pulled suggestions from ML backend causes validation on Labeling Interface

* Fix GPU inference, USE_SAM parameters and default checkpoint paths

---------

Co-authored-by: nik <[email protected]>
niklub and nik authored Apr 10, 2024
1 parent 780071c commit 11aaede
Showing 3 changed files with 54 additions and 80 deletions.
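A note on the `USE_SAM` part of this fix: `os.environ.get` returns strings, so the old compose line `USE_SAM=False` handed the code the truthy string `"False"`. The diff below switches to `get_bool_env` from `label_studio_tools`, which parses the value instead. A minimal sketch of the pitfall and an equivalent parser (the real helper's accepted spellings may differ):

```python
import os

os.environ["USE_SAM"] = "False"                # as the old docker-compose.yml set it
print(bool(os.environ.get("USE_SAM", False)))  # True -- any non-empty string is truthy

# Rough equivalent of label_studio_tools' get_bool_env (a sketch, not its source):
def bool_env(name: str, default: bool = False) -> bool:
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")

print(bool_env("USE_SAM"))  # False
```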
4 changes: 2 additions & 2 deletions label_studio_ml/examples/grounding_dino/README.md
@@ -65,9 +65,9 @@ deploy:
Combine the Segment Anything Model with your text input to automatically generate mask predictions!
-To do this, set `USE_SAM=True` before running.
+To do this, set `USE_SAM=true` before running.

-If you want to use a more efficient version of SAM, set `USE_MOBILE_SAM=True` as well.
+If you want to use a [more efficient version of SAM](https://github.com/ChaoningZhang/MobileSAM), set `USE_MOBILE_SAM=true`.


## Batching Inputs
123 changes: 50 additions & 73 deletions label_studio_ml/examples/grounding_dino/dino.py
@@ -5,7 +5,9 @@
from label_studio_converter import brush
from typing import List, Dict, Optional
from uuid import uuid4
-from label_studio_ml.model import LabelStudioMLBase
+from label_studio_ml.model import LabelStudioMLBase, ModelResponse
+from label_studio_tools.core.utils.params import get_bool_env
+from label_studio_sdk.objects import PredictionValue
from segment_anything.utils.transforms import ResizeLongestSide

from groundingdino.util.inference import load_model, load_image, predict, annotate
@@ -86,37 +88,38 @@ def predict_batch(
os.environ.get("LABEL_STUDIO_HOST") or os.environ.get("LABEL_STUDIO_URL")
)

-USE_SAM = os.environ.get("USE_SAM", False)
-USE_MOBILE_SAM = os.environ.get("USE_MOBILE_SAM", False)
+USE_SAM = get_bool_env("USE_SAM", default=False)
+USE_MOBILE_SAM = get_bool_env("USE_MOBILE_SAM", default=False)

MOBILESAM_CHECKPOINT = os.environ.get("MOBILESAM_CHECKPOINT", "mobile_sam.pt")
SAM_CHECKPOINT = os.environ.get("SAM_CHECKPOINT", "sam_vit_h_4b8939.pth")


if USE_MOBILE_SAM:
logger.info(f"Using Mobile-SAM with checkpoint {MOBILESAM_CHECKPOINT}")
from mobile_sam import SamPredictor, sam_model_registry

model_checkpoint = MOBILESAM_CHECKPOINT
reg_key = 'vit_t'
elif USE_SAM:
logger.info(f"Using SAM with checkpoint {SAM_CHECKPOINT}")
from segment_anything import SamPredictor, sam_model_registry

model_checkpoint = SAM_CHECKPOINT
reg_key = 'vit_h'
else:
logger.info("Using GroundingDINO without SAM")


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device {DEVICE}")


class DINOBackend(LabelStudioMLBase):

-def __init__(self, project_id, **kwargs):
+def __init__(self, **kwargs):
super(DINOBackend, self).__init__(**kwargs)

self.label = None

self.from_name, self.to_name, self.value = None, None, None

self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device {self.device}")

@@ -134,47 +137,29 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
# if there is no context, no interaction has happened yet
return []

-self.from_name_r, self.to_name_r, self.value_r = self.get_first_tag_occurence('RectangleLabels', 'Image')
-self.from_name_b, self.to_name_b, self.value_b = self.get_first_tag_occurence('BrushLabels', 'Image')

-TEXT_PROMPT = context['result'][0]['value']['text'][0]


-x = TEXT_PROMPT.split("_")
+from_name_r, to_name_r, value = self.get_first_tag_occurence('RectangleLabels', 'Image')
+from_name_b, to_name_b, _ = self.get_first_tag_occurence('BrushLabels', 'Image')

-if len(x) > 1:
-    self.label = x[1]
-    self.prompt = x[0]
-else:
-    self.label = x[0]
-    self.prompt = x[0]
+text_prompt = context['result'][0]['value']['text'][0]

print(f"the label is {self.label} and prompt {self.prompt} and {self.from_name_r} and {self.from_name_b}")

-# self.label = TEXT_PROMPT.split("_")[0] # make sure that using as text prompt allows you to label it a certain way

-if self.use_sam == 'True':
-    self.use_sam = True
-if self.use_sam == 'False':
-    self.use_sam = False
-if self.use_ms == 'True':
-    self.use_ms = True
-if self.use_ms == 'False':
-    self.use_ms = False
logger.info(f"the prompt is {text_prompt} and {from_name_r} and {from_name_b}")

final_predictions = []
if len(tasks) > 1:
-final_predictions = self.multiple_tasks(tasks)
+final_predictions = self.multiple_tasks(
+    tasks, text_prompt, from_name_r, to_name_r, from_name_b, to_name_b, value)
elif len(tasks) == 1:
-final_predictions = self.one_task(tasks[0])
+final_predictions = self.one_task(
+    tasks[0], text_prompt, from_name_r, to_name_r, from_name_b, to_name_b, value)

return final_predictions
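The `context['result'][0]['value']['text'][0]` access above reads the prompt typed into a `<TextArea>` tag during interactive pre-annotation. A sketch of the payload shape it relies on (only the fields touched here; real payloads carry more keys, and the tag name is hypothetical):

```python
context = {
    "result": [
        {
            "from_name": "prompt",          # hypothetical TextArea tag name
            "type": "textarea",
            "value": {"text": ["a deer"]},  # the user's text prompt
        }
    ]
}
text_prompt = context["result"][0]["value"]["text"][0]  # -> "a deer"
```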

-def one_task(self, task):
+def one_task(self, task, prompt, from_name_r, to_name_r, from_name_b, to_name_b, value):
all_points = []
all_scores = []
all_lengths = []
predictions = []
-raw_img_path = task['data']['image']
+raw_img_path = task['data'][value]

try:
img_path = self.get_local_path(
@@ -192,7 +177,7 @@ def one_task(self, task):
boxes, logits, _ = predict(
model=groundingdino_model,
image=img,
-caption=self.prompt,
+caption=prompt,
box_threshold=float(BOX_THRESHOLD),
text_threshold=float(TEXT_THRESHOLD),
device=DEVICE
@@ -210,19 +195,21 @@ def one_task(self, task):
all_lengths.append((H, W))

if self.use_ms or self.use_sam:
-predictions.append(self.get_sam_results(img_path, all_points, all_lengths))
+# get <BrushLabels> results
+predictions.append(self.get_sam_results(img_path, all_points, all_lengths, from_name_b, to_name_b))
else:
-predictions.append(self.get_results(all_points, all_scores, all_lengths))
+# get <RectangleLabels> results
+predictions.append(self.get_results(all_points, all_scores, all_lengths, from_name_r, to_name_r))

return predictions

-def multiple_tasks(self, tasks):
+def multiple_tasks(self, tasks, prompt, from_name_r, to_name_r, from_name_b, to_name_b, value):

# first getting all the image paths
image_paths = []

for task in tasks:
-raw_img_path = task['data']['image']
+raw_img_path = task['data'][value]

try:
img_path = self.get_local_path(
@@ -237,7 +224,7 @@ def multiple_tasks(self, tasks):

image_paths.append(img_path)

-boxes, logits, lengths = self.batch_dino(image_paths)
+boxes, logits, lengths = self.batch_dino(image_paths, prompt)

box_by_task = []
for (box_task, (H, W)) in zip(boxes, lengths):
@@ -248,7 +235,7 @@ def multiple_tasks(self, tasks):

if self.use_ms or self.use_sam:
batched_output = self.batch_sam(input_boxes_list=box_by_task, image_paths=image_paths)
-predictions = self.get_batched_sam_results(batched_output)
+predictions = self.get_batched_sam_results(batched_output, from_name_b, to_name_b)

else:
predictions = []
@@ -265,12 +252,12 @@ def multiple_tasks(self, tasks):
all_scores.append(logit)
all_lengths.append((H, W)) # figure out how to get this

-predictions.append(self.get_results(all_points, all_scores, all_lengths))
+predictions.append(self.get_results(all_points, all_scores, all_lengths, from_name_r, to_name_r))

return predictions

# make sure you use new github repo when predicting in batch
-def batch_dino(self, image_paths):
+def batch_dino(self, image_paths, prompt):
# text prompt is same as self.label
loaded_images = []
lengths = []
@@ -288,9 +275,9 @@ def batch_dino(self, image_paths):
boxes, logits, _ = predict_batch(
model=groundingdino_model,
images=images,
-caption=self.prompt, # text prompt is same as self.label
+caption=prompt, # text prompt is same as self.label
box_threshold=float(BOX_THRESHOLD),
-text_threshold = float(TEXT_THRESHOLD),
+text_threshold=float(TEXT_THRESHOLD),
device=self.device
)

Expand All @@ -301,7 +288,7 @@ def batch_dino(self, image_paths):
boxes, logits, _ = predict(
model=groundingdino_model,
image=img,
-caption=self.prompt,
+caption=prompt,
box_threshold=float(BOX_THRESHOLD),
text_threshold=float(TEXT_THRESHOLD),
device=DEVICE
@@ -314,9 +301,6 @@ def batch_dino(self, image_paths):

return boxes, logits, lengths
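GroundingDINO's `predict` returns boxes as normalized `(cx, cy, w, h)`, which is why the `(H, W)` lengths are carried alongside: downstream, the boxes are scaled to pixels and converted to corner points. A worked sketch of that conversion (upstream GroundingDINO uses `torchvision.ops.box_convert` for this; the corresponding lines in this file are collapsed here):

```python
import torch
from torchvision.ops import box_convert

H, W = 600, 800
boxes_norm = torch.tensor([[0.5, 0.5, 0.25, 0.5]])  # one box centered in the image
boxes_px = boxes_norm * torch.tensor([W, H, W, H])  # (cx, cy, w, h) in pixels
points = box_convert(boxes_px, in_fmt="cxcywh", out_fmt="xyxy")
print(points)  # tensor([[300., 150., 500., 450.]])
```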

def batch_sam(self, input_boxes_list, image_paths):

resize_transform = ResizeLongestSide(self.sam.image_encoder.img_size)
@@ -327,9 +311,7 @@ def prepare_image(image, transform, device):
image = torch.as_tensor(image, device=device.device)
return image.permute(2, 0, 1).contiguous()


batched_input = []
-lengths = []
for input_box, path in zip(input_boxes_list, image_paths):
image = cv2.imread(path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -343,30 +325,27 @@ def prepare_image(image, transform, device):

return batched_output
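`batch_sam` assembles the list-of-dicts input that segment-anything's batched `Sam.forward` expects. Roughly, each entry built above looks like this (a sketch using the method's own names):

```python
# One entry per image; the keys follow segment-anything's batched-input format.
entry = {
    "image": prepare_image(image, resize_transform, self.sam),              # CxHxW tensor
    "boxes": resize_transform.apply_boxes_torch(input_box, image.shape[:2]),
    "original_size": image.shape[:2],                                       # (H, W)
}
batched_output = self.sam([entry], multimask_output=False)
```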

-def get_batched_sam_results(self, batched_output):
+def get_batched_sam_results(self, batched_output, from_name_b, to_name_b):

predictions = []

for output in batched_output:
masks = output['masks']
masks = masks[:, 0, :, :].cpu().numpy().astype(np.uint8)


probs = output['iou_predictions'].cpu().numpy()


num_masks = masks.shape[0]
height = masks.shape[-2]
width = masks.shape[-1]

lengths = [(height, width)] * num_masks

-predictions.append(self.sam_predictions(masks, probs, lengths))
+predictions.append(self.sam_predictions(masks, probs, lengths, from_name_b, to_name_b))

return predictions
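For reference, the shapes implied by the indexing above (with `multimask_output=False`, SAM returns one mask per input box):

```python
import numpy as np
import torch

# Hypothetical output entry mirroring get_batched_sam_results' indexing:
output = {
    "masks": torch.zeros(3, 1, 600, 800, dtype=torch.bool),  # (num_boxes, 1, H, W)
    "iou_predictions": torch.rand(3, 1),                     # per-mask quality scores
}
masks = output["masks"][:, 0, :, :].cpu().numpy().astype(np.uint8)  # (3, 600, 800)
probs = output["iou_predictions"].cpu().numpy()                     # used as 'score'
```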


-def get_results(self, all_points, all_scores, all_lengths):
+def get_results(self, all_points, all_scores, all_lengths, from_name_r, to_name_r):

results = []

@@ -379,14 +358,13 @@ def get_results(self, all_points, all_scores, all_lengths):
#TODO: add model version
results.append({
'id': label_id,
-'from_name': self.from_name_r,
-'to_name': self.to_name_r,
+'from_name': from_name_r,
+'to_name': to_name_r,
'original_width': width,
'original_height': height,
'image_rotation': 0,
'value': {
'rotation': 0,
-'rectanglelabels': [self.label],
'width': (points[2] - points[0]) / width * 100,
'height': (points[3] - points[1]) / height * 100,
'x': points[0] / width * 100,
@@ -397,7 +375,6 @@ def get_results(self, all_points, all_scores, all_lengths):
'readonly': False
})


return {
'result': results
}
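Label Studio rectangle values are expressed as percentages of the original image, which is exactly what the arithmetic in `get_results` encodes. A worked example with hypothetical numbers:

```python
x1, y1, x2, y2 = 120.0, 80.0, 360.0, 240.0  # detected box in pixels, (x1, y1, x2, y2)
W, H = 800, 600                             # original image size

value = {
    "rotation": 0,
    "x": x1 / W * 100,              # 15.0
    "y": y1 / H * 100,              # ~13.33
    "width": (x2 - x1) / W * 100,   # 30.0
    "height": (y2 - y1) / H * 100,  # ~26.67
}
```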
@@ -406,16 +383,17 @@ def get_sam_results(
self,
img_path,
input_boxes,
-lengths
+lengths,
+from_name_b,
+to_name_b
):
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
self.predictor.set_image(image)

input_boxes = torch.from_numpy(np.array(input_boxes))


-transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
+transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2]).to(self.device)
masks, probs, _ = self.predictor.predict_torch(
point_coords=None,
point_labels=None,
Expand All @@ -426,10 +404,10 @@ def get_sam_results(
masks = masks[:, 0, :, :].cpu().numpy().astype(np.uint8)
probs = probs.cpu().numpy()

-return self.sam_predictions(masks, probs, lengths)
+return self.sam_predictions(masks, probs, lengths, from_name_b, to_name_b)
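The appended `.to(self.device)` above is the GPU-inference fix from the commit message: `torch.from_numpy` always produces a CPU tensor, while `predict_torch` needs the boxes on the same device as the SAM weights. A minimal illustration:

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
boxes = torch.from_numpy(np.array([[10, 20, 200, 240]]))  # lands on the CPU
boxes = boxes.to(device)  # without this, CUDA runs fail with a device-mismatch error
```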

# takes straight masks and returns predictions
-def sam_predictions(self, masks, probs, lengths):
+def sam_predictions(self, masks, probs, lengths, from_name_b, to_name_b):

results = []

@@ -444,15 +422,14 @@ def sam_predictions(self, masks, probs, lengths):

results.append({
'id': label_id,
-'from_name': self.from_name_b,
-'to_name': self.to_name_b,
+'from_name': from_name_b,
+'to_name': to_name_b,
'original_width': width,
'original_height': height,
'image_rotation': 0,
'value': {
'format': 'rle',
-'rle': rle,
-'brushlabels': [self.label],
+'rle': rle
},
'score': float(prob[0]),
'type': 'brushlabels',
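The `'format': 'rle'` payload built in `sam_predictions` comes from `label_studio_converter`'s brush helpers. A sketch of encoding one binary mask (the 0-255 scaling mirrors common usage; the exact lines in this file are collapsed):

```python
import numpy as np
from label_studio_converter import brush

mask = np.zeros((600, 800), dtype=np.uint8)
mask[100:300, 200:500] = 255  # hypothetical segmented region
rle = brush.mask2rle(mask)    # run-length encoding Label Studio's brush tool renders
```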
7 changes: 2 additions & 5 deletions label_studio_ml/examples/grounding_dino/docker-compose.yml
@@ -18,11 +18,8 @@ services:
- LABEL_STUDIO_ACCESS_TOKEN=

# use these if you want to use segment anything instead of bounding box predictions from input text prompts
-  - USE_SAM=False # if you want to automatically generate segment anything model predictions
-  - USE_MOBILE_SAM=False # whether you want to use a more efficient, yet a bit less accurate, version of the segment anything model
-  - MOBILE_SAM_CHECKPOINT= # in case you want to point to another directory than where Docker automatically places the model
-  - SAM_CHECKPOINT= # if you want to use regular SAM and point to another directory than where Docker automatically places the model
+  - USE_SAM=false # if you want to automatically generate segment anything model predictions
+  - USE_MOBILE_SAM=false # whether you want to use a more efficient, yet a bit less accurate, version of the segment anything model

- BOX_THRESHOLD=0.30
- TEXT_THRESHOLD=0.25
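One more detail of the compose cleanup: the removed override was spelled `MOBILE_SAM_CHECKPOINT`, while `dino.py` reads `MOBILESAM_CHECKPOINT` (see the diff above), so the override appears never to have taken effect. With the lines gone, the in-code defaults apply:

```python
import os

# Defaults used by dino.py when no environment override is present:
MOBILESAM_CHECKPOINT = os.environ.get("MOBILESAM_CHECKPOINT", "mobile_sam.pt")
SAM_CHECKPOINT = os.environ.get("SAM_CHECKPOINT", "sam_vit_h_4b8939.pth")
```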
