Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
add assert_is_point
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyu2172 committed Feb 27, 2018
1 parent 18f2be3 commit abd3d7a
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 0 deletions.
1 change: 1 addition & 0 deletions chainercv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from chainercv.utils.testing import assert_is_detection_link # NOQA
from chainercv.utils.testing import assert_is_image # NOQA
from chainercv.utils.testing import assert_is_label_dataset # NOQA
from chainercv.utils.testing import assert_is_point # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing import ConstantStubLink # NOQA
Expand Down
1 change: 1 addition & 0 deletions chainercv/utils/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from chainercv.utils.testing.assertions import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions import assert_is_image # NOQA
from chainercv.utils.testing.assertions import assert_is_label_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_point # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing.constant_stub_link import ConstantStubLink # NOQA
Expand Down
1 change: 1 addition & 0 deletions chainercv/utils/testing/assertions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from chainercv.utils.testing.assertions.assert_is_detection_link import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image # NOQA
from chainercv.utils.testing.assertions.assert_is_label_dataset import assert_is_label_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_point import assert_is_point # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_dataset import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_link import assert_is_semantic_segmentation_link # NOQA
43 changes: 43 additions & 0 deletions chainercv/utils/testing/assertions/assert_is_point.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np


def assert_is_point(point, mask=None, size=None):
"""Checks if points satisfy the format.
This function checks if given points satisfy the format and
raises an :class:`AssertionError` when the points violate the convention.
Args:
point (~numpy.ndarray): Points to be checked.
mask (~numpy.ndarray): A mask of the points.
If this is :obj:`None`, all points are regarded as valid.
size (tuple of ints): The size of an image.
If this argument is specified,
the coordinates of valid points are checked to be within the image.
"""

assert isinstance(point, np.ndarray), \
'point must be a numpy.ndarray.'
assert point.dtype == np.float32, \
'The type of point must be numpy.float32.'
assert point.shape[1:] == (2,), \
'The shape of point must be (*, 2).'

if mask is not None:
assert isinstance(mask, np.ndarray), \
'a mask of points must be a numpy.ndarray.'
assert mask.dtype == np.bool, \
'The type of mask must be numpy.bool.'
assert mask.ndim == 1, \
'The dimensionality of a mask must be one.'
assert mask.shape[0] == point.shape[0], \
'The size of the first axis should be the same for ' \
'corresponding point and mask.'
valid_point = point[mask]
else:
valid_point = point

if size is not None:
assert (valid_point >= 0).all() and (valid_point <= size).all(),\
'The coordinates of valid points ' \
'should not exceed the size of image.'
4 changes: 4 additions & 0 deletions docs/source/reference/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ assert_is_label_dataset
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: assert_is_label_dataset

assert_is_point
~~~~~~~~~~~~~~~
.. autofunction:: assert_is_point

assert_is_semantic_segmentation_dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: assert_is_semantic_segmentation_dataset
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import numpy as np
import unittest

from chainer import testing

from chainercv.utils import assert_is_point


@testing.parameterize(
# no mask and size
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32),
'valid': True},
{'point': ((1., 2.), (4., 8.)),
'valid': False},
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.int32),
'valid': False},
{'point': np.random.uniform(-1, 1, size=(10, 3)).astype(np.float32),
'valid': False},
# use mask, no size
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool),
'valid': True},
{'point': np.random.uniform(-1, 1, size=(4, 2)).astype(np.float32),
'mask': (True, True, True, True),
'valid': False},
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(10,)).astype(np.int32),
'valid': False},
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(10, 2)).astype(np.bool),
'valid': False},
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(9,)).astype(np.bool),
'valid': False},
# use mask and size
{'point': np.random.uniform(0, 32, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool),
'size': (32, 32),
'valid': True},
{'point': np.random.uniform(32, 64, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool),
'size': (32, 32),
'valid': False},
)
class TestAssertIsPoint(unittest.TestCase):

def setUp(self):
if not hasattr(self, 'mask'):
self.mask = None
if not hasattr(self, 'size'):
self.size = None

def test_assert_is_bbox(self):
if self.valid:
assert_is_point(self.point, self.mask, self.size)
else:
with self.assertRaises(AssertionError):
assert_is_point(self.point, self.mask, self.size)


testing.run_module(__name__, __file__)

0 comments on commit abd3d7a

Please sign in to comment.