-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsam_explore.py
367 lines (237 loc) · 10.5 KB
/
sam_explore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 20 11:42:08 2024
@author: 14055
"""
import os
os.chdir(r'C:\Users\14055\Desktop\sam_experiments')
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import torch
import torchvision
# Report the installed library versions and whether a CUDA device is visible.
for label, value in (
    ("PyTorch version:", torch.__version__),
    ("Torchvision version:", torchvision.__version__),
    ("CUDA is available:", torch.cuda.is_available()),
):
    print(label, value)
def show_anns(anns, alpha=0.35):
    """
    Overlay mask annotations on the current matplotlib axes.

    Args:
    - anns (list of dicts): Mask records, each with an 'area' entry and a boolean
      'segmentation' array used as a pixel index; all segmentations must share the
      height/width of the first (largest) one.
    - alpha (float): Opacity of every colored mask. Default 0.35.

    Each mask gets a random RGB color. Masks are painted largest-first, so smaller
    masks painted later overwrite (and stay visible on top of) larger ones.
    Does nothing when `anns` is empty.
    """
    if not anns:
        return
    by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    ax = plt.gca()
    # Keep the axes limits set by the underlying image instead of rescaling to the overlay.
    ax.set_autoscale_on(False)
    height, width = by_area[0]['segmentation'].shape[:2]
    # RGBA canvas that starts fully transparent (alpha channel = 0 everywhere).
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0
    for ann in by_area:
        overlay[ann['segmentation']] = np.concatenate([np.random.random(3), [alpha]])
    ax.imshow(overlay)
# Load the input photo and convert it from OpenCV's BGR channel order to RGB.
image_path = r'C:\Users\14055\Downloads\IMG_20240221_113827.jpg'
image = cv2.imread(image_path)
# cv2.imread does not raise on a missing/unreadable file; it returns None, which
# would otherwise surface as a confusing error inside cv2.cvtColor.
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()
# Load the SAM ViT-H checkpoint and generate masks automatically for `image`.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Fall back to CPU so the script still runs on machines without a CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
# Generate masks from the original image only. `highlighted_image` is created at
# the end of this script, so calling generate() on it here raised NameError.
masks = mask_generator.generate(image)
print(len(masks))
# Guard against an empty result before peeking at the first mask record.
if masks:
    print(masks[0].keys())
# Overlay every generated mask on the input image.
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
# Sum of the per-mask 'area' values (pixels shared by several masks are counted once per mask).
area_sum = sum(mask['area'] for mask in masks)
# Inspect one individual mask's area; guarded because indexing masks[4]
# raises IndexError when fewer than five masks were generated.
if len(masks) > 4:
    print(masks[4]['area'])
# Disabled experiment, kept inside string literals so it never executes: a second
# SamAutomaticMaskGenerator with explicitly tuned parameters, plus a plot of its output.
# NOTE(review): the plotting code refers to `masks2`, which is never assigned anywhere;
# re-enabling this would also need e.g. `masks2 = mask_generator_2.generate(image)`.
'''
mask generation params
'''
'''
mask_generator_2 = SamAutomaticMaskGenerator(
model=sam,
points_per_side=32,
pred_iou_thresh=0.86,
stability_score_thresh=0.92,
crop_n_layers=1,
crop_n_points_downscale_factor=2,
min_mask_region_area=100, # Requires open-cv to run post-processing
)
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show()
'''
#------------------------------------------------
#------------------------------------------------
#------------------------------------------------
#----------- TODO: compute the rectangularity of every mask and print the values (possibly as a histogram)
#----------- for the masks with the highest rectangularity, return their indices
#----------- merge each of those masks with the largest mask that surrounds it
#----------- also suggest a ... (this note is truncated in the original)
def calculate_rectangularity(mask_object):
    """
    Calculate the rectangularity of an object represented by a mask object.

    Rectangularity is the ratio of the object's pixel area to the area of its
    bounding box; larger values mean the object fills more of its box.

    Args:
    - mask_object (dict): A mask record with an 'area' entry (pixel count) and a
      'bbox' entry, assumed to be [x, y, width, height].

    Returns:
    - rectangularity (float): area / (width * height), or 0.0 when the bounding
      box is degenerate (zero width or height) instead of dividing by zero.
    """
    object_area = mask_object['area']
    bbox = mask_object['bbox']
    bbox_area = bbox[2] * bbox[3]
    # A zero-width/height box has no meaningful ratio; report it as not rectangular.
    if bbox_area == 0:
        return 0.0
    return object_area / bbox_area
# One [rectangularity, area] pair per mask, in the same order as `masks`.
rect_area_list = [
    [calculate_rectangularity(mask_object), mask_object['area']]
    for mask_object in masks
]
# Inspect a single mask (by index) overlaid on the image; an out-of-range index
# yields an empty slice, which show_anns simply skips.
inspect_idx = 2
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks[inspect_idx:inspect_idx + 1])
plt.axis('off')
plt.show()
def _bboxes_overlap(box_a, box_b):
    """Return True if two [x, y, width, height] boxes overlap (boxes that only touch do not)."""
    return (box_a[0] < box_b[0] + box_b[2]
            and box_a[0] + box_a[2] > box_b[0]
            and box_a[1] < box_b[1] + box_b[3]
            and box_a[1] + box_a[3] > box_b[1])


def _tight_bbox(segmentation):
    """Return the [x, y, width, height] box around the nonzero pixels of a 2-D mask, or None if there are none."""
    row_idx = np.where(np.any(segmentation, axis=1))[0]
    col_idx = np.where(np.any(segmentation, axis=0))[0]
    if row_idx.size == 0:
        return None
    row_min, row_max = row_idx[0], row_idx[-1]
    col_min, col_max = col_idx[0], col_idx[-1]
    return [int(col_min), int(row_min),
            int(col_max - col_min + 1), int(row_max - row_min + 1)]


def merge_overlapping_masks(masks):
    """
    Merge overlapping masks based on their bounding boxes.

    Masks are processed in order. Each mask absorbs every later, not-yet-merged
    mask whose bounding box overlaps its own *original* box. Segmentations are
    combined with a logical OR, and 'area' and 'bbox' are recomputed from the result.
    This is a single pass: boxes that only overlap after a merge grows them are not
    merged again.

    The input list and its dicts are left unmodified; returned records are shallow
    copies (other keys are carried over from the absorbing mask).

    Args:
    - masks (list of dicts): Mask objects, each with a 'segmentation' 2-D numpy array
      and a 'bbox' entry [x, y, width, height].

    Returns:
    - merged_masks (list of dicts): The merged mask objects. If a resulting
      segmentation has no set pixels, the record keeps its original 'bbox'.
    """
    remaining = list(masks)  # work on a copy so the caller's list is not consumed
    merged_masks = []
    while remaining:
        current_mask = dict(remaining[0])  # shallow copy: don't overwrite the caller's dict
        rest = remaining[1:]
        current_bbox = current_mask['bbox']
        hits = [_bboxes_overlap(current_bbox, other['bbox']) for other in rest]
        overlapping = [other for other, hit in zip(rest, hits) if hit]
        remaining = [other for other, hit in zip(rest, hits) if not hit]

        segmentation = current_mask['segmentation']
        for other in overlapping:
            segmentation = np.logical_or(segmentation, other['segmentation'])

        current_mask['segmentation'] = segmentation
        current_mask['area'] = int(np.sum(segmentation))
        new_bbox = _tight_bbox(segmentation)
        # An empty segmentation has no extent to measure; keep the previous bbox.
        if new_bbox is not None:
            current_mask['bbox'] = new_bbox
        merged_masks.append(current_mask)
    return merged_masks
# Merge masks whose bounding boxes overlap. Shallow copies are passed so the merge
# can neither consume the `masks` list nor overwrite its dicts; `masks` is used
# again further down (binary mask image).
merged_masks = merge_overlapping_masks([dict(mask) for mask in masks])
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(merged_masks)
plt.axis('off')
plt.show()
# Sum of the per-mask 'area' values after merging.
area_sum = sum(mask['area'] for mask in merged_masks)
#-------------- create a binary (0/1) image marking every pixel covered by any mask
def create_binary_mask_image(masks, image_shape):
    """
    Build a single 0/1 image that is the union of all mask segmentations.

    Args:
    - masks (list of dicts): Mask objects whose 'segmentation' entry is a numpy
      array marking the masked pixels.
    - image_shape (tuple): Shape of the original image, e.g. (height, width).

    Returns:
    - binary_image (np.array): uint8 array that is 1 wherever any mask is set and 0 elsewhere.
    """
    segmentations = [entry['segmentation'] for entry in masks]
    union = np.zeros(image_shape, dtype=np.uint8)
    for seg in segmentations:
        # Logical OR (not bitwise) so non-boolean masks still contribute exactly 1.
        union = np.logical_or(union, seg)
    return np.asarray(union, dtype=np.uint8)
# Union of all masks in `masks`, rendered as a single 0/1 image the size of `image`.
input_image_shape = image.shape[:2]  # (height, width) of the loaded image
binary_mask_image = create_binary_mask_image(masks, input_image_shape)
plt.imshow(binary_mask_image)
#--- Display the image with its strongly green pixels highlighted; the next step is
#--- to keep only the masks that contain a large green component.
import cv2
def highlight_green_parts(image, threshold=150, highlight_color=(0, 255, 0), alpha=0.5):
    """
    Highlight parts of an image with high green color values.

    Args:
    - image (numpy.ndarray): 3-channel image array. Green is the middle channel
      in both RGB and BGR order, so either order works (the `image` used in this
      script was converted to RGB with cv2.cvtColor).
    - threshold (int): A pixel counts as 'high green' when its green value is
      greater than this AND greater than both other channels. Default is 150.
    - highlight_color (tuple): Color painted onto high-green pixels, in the same
      channel order as `image`. Default (0, 255, 0) is green in both RGB and BGR.
    - alpha (float): Weight of the highlight layer in cv2.addWeighted. Default is 0.5.

    Returns:
    - highlighted_image (numpy.ndarray): `image` blended with the highlight layer.
    """
    # cv2.split returns channels in array order. Only the middle (green) one matters
    # by name: the two outer channels are compared symmetrically below.
    outer_a, green, outer_b = cv2.split(image)
    mask = (green > threshold) & (green > outer_a) & (green > outer_b)
    # Highlight layer: zero everywhere except the high-green pixels.
    highlight = np.zeros_like(image)
    highlight[mask] = highlight_color
    highlighted_image = cv2.addWeighted(image, 1, highlight, alpha, 0)
    return highlighted_image
highlighted_image = highlight_green_parts(image)
# Display the result with matplotlib, like every other figure in this script.
# No cv2.imshow window is ever opened here, so cv2.waitKey/cv2.destroyAllWindows
# are not needed.
plt.figure(figsize=(20,20))
plt.imshow(highlighted_image)
plt.axis('off')
plt.show()