From abd3d7acfa13683434299841a7f26b87761dd701 Mon Sep 17 00:00:00 2001 From: Yusuke Niitani Date: Tue, 27 Feb 2018 13:53:29 +0900 Subject: [PATCH] add assert_is_point --- chainercv/utils/__init__.py | 1 + chainercv/utils/testing/__init__.py | 1 + .../utils/testing/assertions/__init__.py | 1 + .../testing/assertions/assert_is_point.py | 43 +++++++++++++ docs/source/reference/utils.rst | 4 ++ .../assertions_tests/test_assert_is_point.py | 61 +++++++++++++++++++ 6 files changed, 111 insertions(+) create mode 100644 chainercv/utils/testing/assertions/assert_is_point.py create mode 100644 tests/utils_tests/testing_tests/assertions_tests/test_assert_is_point.py diff --git a/chainercv/utils/__init__.py b/chainercv/utils/__init__.py index 47f48e3e77..6af87d74b1 100644 --- a/chainercv/utils/__init__.py +++ b/chainercv/utils/__init__.py @@ -14,6 +14,7 @@ from chainercv.utils.testing import assert_is_detection_link # NOQA from chainercv.utils.testing import assert_is_image # NOQA from chainercv.utils.testing import assert_is_label_dataset # NOQA +from chainercv.utils.testing import assert_is_point # NOQA from chainercv.utils.testing import assert_is_semantic_segmentation_dataset # NOQA from chainercv.utils.testing import assert_is_semantic_segmentation_link # NOQA from chainercv.utils.testing import ConstantStubLink # NOQA diff --git a/chainercv/utils/testing/__init__.py b/chainercv/utils/testing/__init__.py index 3c636efe1b..91582f1085 100644 --- a/chainercv/utils/testing/__init__.py +++ b/chainercv/utils/testing/__init__.py @@ -3,6 +3,7 @@ from chainercv.utils.testing.assertions import assert_is_detection_link # NOQA from chainercv.utils.testing.assertions import assert_is_image # NOQA from chainercv.utils.testing.assertions import assert_is_label_dataset # NOQA +from chainercv.utils.testing.assertions import assert_is_point # NOQA from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_dataset # NOQA from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_link # NOQA from chainercv.utils.testing.constant_stub_link import ConstantStubLink # NOQA diff --git a/chainercv/utils/testing/assertions/__init__.py b/chainercv/utils/testing/assertions/__init__.py index 11f6420563..41f20f3ff7 100644 --- a/chainercv/utils/testing/assertions/__init__.py +++ b/chainercv/utils/testing/assertions/__init__.py @@ -3,5 +3,6 @@ from chainercv.utils.testing.assertions.assert_is_detection_link import assert_is_detection_link # NOQA from chainercv.utils.testing.assertions.assert_is_image import assert_is_image # NOQA from chainercv.utils.testing.assertions.assert_is_label_dataset import assert_is_label_dataset # NOQA +from chainercv.utils.testing.assertions.assert_is_point import assert_is_point # NOQA from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_dataset import assert_is_semantic_segmentation_dataset # NOQA from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_link import assert_is_semantic_segmentation_link # NOQA diff --git a/chainercv/utils/testing/assertions/assert_is_point.py b/chainercv/utils/testing/assertions/assert_is_point.py new file mode 100644 index 0000000000..2596da6605 --- /dev/null +++ b/chainercv/utils/testing/assertions/assert_is_point.py @@ -0,0 +1,43 @@ +import numpy as np + + +def assert_is_point(point, mask=None, size=None): + """Checks if points satisfy the format. + + This function checks if given points satisfy the format and + raises an :class:`AssertionError` when the points violate the convention. + + Args: + point (~numpy.ndarray): Points to be checked. + mask (~numpy.ndarray): A mask of the points. + If this is :obj:`None`, all points are regarded as valid. + size (tuple of ints): The size of an image. + If this argument is specified, + the coordinates of valid points are checked to be within the image. + """ + + assert isinstance(point, np.ndarray), \ + 'point must be a numpy.ndarray.' + assert point.dtype == np.float32, \ + 'The type of point must be numpy.float32.' + assert point.shape[1:] == (2,), \ + 'The shape of point must be (*, 2).' + + if mask is not None: + assert isinstance(mask, np.ndarray), \ + 'a mask of points must be a numpy.ndarray.' + assert mask.dtype == np.bool, \ + 'The type of mask must be numpy.bool.' + assert mask.ndim == 1, \ + 'The dimensionality of a mask must be one.' + assert mask.shape[0] == point.shape[0], \ + 'The size of the first axis should be the same for ' \ + 'corresponding point and mask.' + valid_point = point[mask] + else: + valid_point = point + + if size is not None: + assert (valid_point >= 0).all() and (valid_point <= size).all(),\ + 'The coordinates of valid points ' \ + 'should not exceed the size of image.' diff --git a/docs/source/reference/utils.rst b/docs/source/reference/utils.rst index 3b952002b4..5a828a3366 100644 --- a/docs/source/reference/utils.rst +++ b/docs/source/reference/utils.rst @@ -87,6 +87,10 @@ assert_is_label_dataset ~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: assert_is_label_dataset +assert_is_point +~~~~~~~~~~~~~~~ +.. autofunction:: assert_is_point + assert_is_semantic_segmentation_dataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: assert_is_semantic_segmentation_dataset diff --git a/tests/utils_tests/testing_tests/assertions_tests/test_assert_is_point.py b/tests/utils_tests/testing_tests/assertions_tests/test_assert_is_point.py new file mode 100644 index 0000000000..2a058fd6dd --- /dev/null +++ b/tests/utils_tests/testing_tests/assertions_tests/test_assert_is_point.py @@ -0,0 +1,61 @@ +import numpy as np +import unittest + +from chainer import testing + +from chainercv.utils import assert_is_point + + +@testing.parameterize( + # no mask and size + {'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32), + 'valid': True}, + {'point': ((1., 2.), (4., 8.)), + 'valid': False}, + {'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.int32), + 'valid': False}, + {'point': np.random.uniform(-1, 1, size=(10, 3)).astype(np.float32), + 'valid': False}, + # use mask, no size + {'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32), + 'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool), + 'valid': True}, + {'point': np.random.uniform(-1, 1, size=(4, 2)).astype(np.float32), + 'mask': (True, True, True, True), + 'valid': False}, + {'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32), + 'mask': np.random.randint(0, 2, size=(10,)).astype(np.int32), + 'valid': False}, + {'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32), + 'mask': np.random.randint(0, 2, size=(10, 2)).astype(np.bool), + 'valid': False}, + {'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32), + 'mask': np.random.randint(0, 2, size=(9,)).astype(np.bool), + 'valid': False}, + # use mask and size + {'point': np.random.uniform(0, 32, size=(10, 2)).astype(np.float32), + 'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool), + 'size': (32, 32), + 'valid': True}, + {'point': np.random.uniform(32, 64, size=(10, 2)).astype(np.float32), + 'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool), + 'size': (32, 32), + 'valid': False}, +) +class TestAssertIsPoint(unittest.TestCase): + + def setUp(self): + if not hasattr(self, 'mask'): + self.mask = None + if not hasattr(self, 'size'): + self.size = None + + def test_assert_is_bbox(self): + if self.valid: + assert_is_point(self.point, self.mask, self.size) + else: + with self.assertRaises(AssertionError): + assert_is_point(self.point, self.mask, self.size) + + +testing.run_module(__name__, __file__)