Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Add assert_is_point #524

Merged
merged 1 commit into from
Feb 27, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chainercv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from chainercv.utils.testing import assert_is_detection_link # NOQA
from chainercv.utils.testing import assert_is_image # NOQA
from chainercv.utils.testing import assert_is_label_dataset # NOQA
from chainercv.utils.testing import assert_is_point # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing import ConstantStubLink # NOQA
Expand Down
1 change: 1 addition & 0 deletions chainercv/utils/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from chainercv.utils.testing.assertions import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions import assert_is_image # NOQA
from chainercv.utils.testing.assertions import assert_is_label_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_point # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing.constant_stub_link import ConstantStubLink # NOQA
Expand Down
1 change: 1 addition & 0 deletions chainercv/utils/testing/assertions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from chainercv.utils.testing.assertions.assert_is_detection_link import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image # NOQA
from chainercv.utils.testing.assertions.assert_is_label_dataset import assert_is_label_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_point import assert_is_point # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_dataset import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_link import assert_is_semantic_segmentation_link # NOQA
43 changes: 43 additions & 0 deletions chainercv/utils/testing/assertions/assert_is_point.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np


def assert_is_point(point, mask=None, size=None):
"""Checks if points satisfy the format.

This function checks if given points satisfy the format and
raises an :class:`AssertionError` when the points violate the convention.

Args:
point (~numpy.ndarray): Points to be checked.
mask (~numpy.ndarray): A mask of the points.
If this is :obj:`None`, all points are regarded as valid.
size (tuple of ints): The size of an image.
If this argument is specified,
the coordinates of valid points are checked to be within the image.
"""

assert isinstance(point, np.ndarray), \
'point must be a numpy.ndarray.'
assert point.dtype == np.float32, \
'The type of point must be numpy.float32.'
assert point.shape[1:] == (2,), \
'The shape of point must be (*, 2).'

if mask is not None:
assert isinstance(mask, np.ndarray), \
'a mask of points must be a numpy.ndarray.'
assert mask.dtype == np.bool, \
'The type of mask must be numpy.bool.'
assert mask.ndim == 1, \
'The dimensionality of a mask must be one.'
assert mask.shape[0] == point.shape[0], \
'The size of the first axis should be the same for ' \
'corresponding point and mask.'
valid_point = point[mask]
else:
valid_point = point

if size is not None:
assert (valid_point >= 0).all() and (valid_point <= size).all(),\
'The coordinates of valid points ' \
'should not exceed the size of image.'
4 changes: 4 additions & 0 deletions docs/source/reference/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ assert_is_label_dataset
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: assert_is_label_dataset

assert_is_point
~~~~~~~~~~~~~~~
.. autofunction:: assert_is_point

assert_is_semantic_segmentation_dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: assert_is_semantic_segmentation_dataset
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy as np
import unittest

from chainer import testing

from chainercv.utils import assert_is_point


@testing.parameterize(
# no mask and size
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32),
'valid': True},
{'point': ((1., 2.), (4., 8.)),
'valid': False},
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.int32),
'valid': False},
{'point': np.random.uniform(-1, 1, size=(10, 3)).astype(np.float32),
'valid': False},
# use mask, no size
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool),
'valid': True},
{'point': np.random.uniform(-1, 1, size=(4, 2)).astype(np.float32),
'mask': (True, True, True, True),
'valid': False},
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(10,)).astype(np.int32),
'valid': False},
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(10, 2)).astype(np.bool),
'valid': False},
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(9,)).astype(np.bool),
'valid': False},
# no mask, use size
{'point': np.random.uniform(0, 32, size=(10, 2)).astype(np.float32),
'size': (32, 32),
'valid': True},
{'point': np.random.uniform(32, 64, size=(10, 2)).astype(np.float32),
'size': (32, 32),
'valid': False},
# use mask and size
{'point': np.random.uniform(0, 32, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool),
'size': (32, 32),
'valid': True},
{'point': np.random.uniform(32, 64, size=(10, 2)).astype(np.float32),
'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool),
'size': (32, 32),
'valid': False},
)
class TestAssertIsPoint(unittest.TestCase):

def setUp(self):
if not hasattr(self, 'mask'):
self.mask = None
if not hasattr(self, 'size'):
self.size = None

def test_assert_is_point(self):
if self.valid:
assert_is_point(self.point, self.mask, self.size)
else:
with self.assertRaises(AssertionError):
assert_is_point(self.point, self.mask, self.size)


testing.run_module(__name__, __file__)