image_segmenter_test.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit test of image segmentation using ImageSegmenter wrapper."""
import unittest

import cv2
import numpy as np

import utils
from image_segmenter import ImageSegmenter
from image_segmenter import ImageSegmenterOptions
from image_segmenter import OutputType

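# Test resources: the TFLite segmentation model (DeepLab v3) and the ground
# truth data used for comparison.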
_MODEL_FILE = 'deeplabv3.tflite'
_IMAGE_FILE = 'test_data/input_image.jpeg'
_GROUND_TRUTH_IMAGE_FILE = 'test_data/ground_truth_segmentation.png'
_GROUND_TRUTH_LABEL_FILE = 'test_data/ground_truth_label.txt'
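# Fraction of pixels that are allowed to differ from the ground truth mask.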
_MATCH_PIXELS_THRESHOLD = 0.01
# def _get_pixels(input: np.ndarray) -> List[int]:
#   """Flatten a numpy array into a list of pixels.
#
#   Args:
#     input: A numpy array.
#
#   Returns:
#     A list of pixels.
#   """
#   return input.flatten().tolist()


class ImageSegmenterTest(unittest.TestCase):

  def _load_ground_truth(self):
    """Load the ground truth segmentation result from the image and label files."""
    self._ground_truth_segmentation = cv2.imread(_GROUND_TRUTH_IMAGE_FILE)
    self._ground_truth_labels = []
    with open(_GROUND_TRUTH_LABEL_FILE) as f:
      self._ground_truth_labels = f.read().splitlines()

  def setUp(self):
    """Initialize the shared variables."""
    super().setUp()
    self._load_ground_truth()
    image = cv2.imread(_IMAGE_FILE)
    self.image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Run segmentation with the TFLite model in CATEGORY_MASK mode.
    segmenter = ImageSegmenter(_MODEL_FILE)
    result = segmenter.segment(self.image)
    self._seg_map_img, self._found_labels = utils.segmentation_map_to_image(
        result)
    self._category_mask = result.masks
  def test_segmentation_category_mask(self):
    """Check if the category mask matches the ground truth."""
    result_pixels = self._seg_map_img.flatten()
    ground_truth_pixels = self._ground_truth_segmentation.flatten()
    self.assertEqual(
        len(result_pixels), len(ground_truth_pixels),
        'Segmentation mask must have the same size as the ground truth.')
    # Count the pixels that disagree with the ground truth mask.
    inconsistent_pixels = [
        1 for idx in range(len(result_pixels))
        if result_pixels[idx] != ground_truth_pixels[idx]
    ]
    self.assertLessEqual(
        len(inconsistent_pixels) / len(result_pixels), _MATCH_PIXELS_THRESHOLD,
        'Segmentation mask must match the ground truth within the allowed '
        'pixel mismatch threshold.')
  def test_segmentation_confidence_mask(self):
    """Check if the confidence mask matches the category mask."""
    # Run segmentation with the TFLite model in CONFIDENCE_MASK mode.
    options = ImageSegmenterOptions(output_type=OutputType.CONFIDENCE_MASK)
    segmenter = ImageSegmenter(_MODEL_FILE, options)
    result = segmenter.segment(self.image)
    # Check if the confidence mask shape is correct.
    self.assertEqual(
        result.masks.shape[2], len(result.colored_labels),
        'The 3rd dimension of the confidence mask must match the number of '
        'categories.')
    # Taking the argmax over the category axis should reproduce the category
    # mask computed in setUp.
    calculated_category_mask = np.argmax(result.masks, axis=2)
    self.assertListEqual(calculated_category_mask.tolist(),
                         self._category_mask.tolist())
  def test_labels(self):
    """Check if the detected labels match the ground truth labels."""
    result_labels = [
        colored_label.label for colored_label in self._found_labels
    ]
    self.assertEqual(result_labels, self._ground_truth_labels)
  def _create_ground_truth_data(
      self,
      output_image_file: str = _GROUND_TRUTH_IMAGE_FILE,
      output_label_file: str = _GROUND_TRUTH_LABEL_FILE) -> None:
    """A util function to regenerate the ground truth data.

    Args:
      output_image_file: Path to save the segmentation map produced by the model.
      output_label_file: Path to save the label list produced by the model.
    """
    # Initialize the image segmentation model.
    segmenter = ImageSegmenter(_MODEL_FILE)
    result = segmenter.segment(self.image)
    seg_map_img, found_labels = utils.segmentation_map_to_image(result)
    cv2.imwrite(output_image_file, seg_map_img)
    with open(output_label_file, 'w') as f:
      f.writelines('\n'.join(
          [color_label.label for color_label in found_labels]))


if __name__ == '__main__':
  unittest.main()