This repository has been archived by the owner on Jul 2, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 303
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
111 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import numpy as np | ||
|
||
|
||
def assert_is_point(point, mask=None, size=None): | ||
"""Checks if points satisfy the format. | ||
This function checks if given points satisfy the format and | ||
raises an :class:`AssertionError` when the points violate the convention. | ||
Args: | ||
point (~numpy.ndarray): Points to be checked. | ||
mask (~numpy.ndarray): A mask of the points. | ||
If this is :obj:`None`, all points are regarded as valid. | ||
size (tuple of ints): The size of an image. | ||
If this argument is specified, | ||
the coordinates of valid points are checked to be within the image. | ||
""" | ||
|
||
assert isinstance(point, np.ndarray), \ | ||
'point must be a numpy.ndarray.' | ||
assert point.dtype == np.float32, \ | ||
'The type of point must be numpy.float32.' | ||
assert point.shape[1:] == (2,), \ | ||
'The shape of point must be (*, 2).' | ||
|
||
if mask is not None: | ||
assert isinstance(mask, np.ndarray), \ | ||
'a mask of points must be a numpy.ndarray.' | ||
assert mask.dtype == np.bool, \ | ||
'The type of mask must be numpy.bool.' | ||
assert mask.ndim == 1, \ | ||
'The dimensionality of a mask must be one.' | ||
assert mask.shape[0] == point.shape[0], \ | ||
'The size of the first axis should be the same for ' \ | ||
'corresponding point and mask.' | ||
valid_point = point[mask] | ||
else: | ||
valid_point = point | ||
|
||
if size is not None: | ||
assert (valid_point >= 0).all() and (valid_point <= size).all(),\ | ||
'The coordinates of valid points ' \ | ||
'should not exceed the size of image.' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
tests/utils_tests/testing_tests/assertions_tests/test_assert_is_point.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import numpy as np | ||
import unittest | ||
|
||
from chainer import testing | ||
|
||
from chainercv.utils import assert_is_point | ||
|
||
|
||
@testing.parameterize( | ||
# no mask and size | ||
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32), | ||
'valid': True}, | ||
{'point': ((1., 2.), (4., 8.)), | ||
'valid': False}, | ||
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.int32), | ||
'valid': False}, | ||
{'point': np.random.uniform(-1, 1, size=(10, 3)).astype(np.float32), | ||
'valid': False}, | ||
# use mask, no size | ||
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32), | ||
'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool), | ||
'valid': True}, | ||
{'point': np.random.uniform(-1, 1, size=(4, 2)).astype(np.float32), | ||
'mask': (True, True, True, True), | ||
'valid': False}, | ||
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32), | ||
'mask': np.random.randint(0, 2, size=(10,)).astype(np.int32), | ||
'valid': False}, | ||
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32), | ||
'mask': np.random.randint(0, 2, size=(10, 2)).astype(np.bool), | ||
'valid': False}, | ||
{'point': np.random.uniform(-1, 1, size=(10, 2)).astype(np.float32), | ||
'mask': np.random.randint(0, 2, size=(9,)).astype(np.bool), | ||
'valid': False}, | ||
# use mask and size | ||
{'point': np.random.uniform(0, 32, size=(10, 2)).astype(np.float32), | ||
'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool), | ||
'size': (32, 32), | ||
'valid': True}, | ||
{'point': np.random.uniform(32, 64, size=(10, 2)).astype(np.float32), | ||
'mask': np.random.randint(0, 2, size=(10,)).astype(np.bool), | ||
'size': (32, 32), | ||
'valid': False}, | ||
) | ||
class TestAssertIsPoint(unittest.TestCase): | ||
|
||
def setUp(self): | ||
if not hasattr(self, 'mask'): | ||
self.mask = None | ||
if not hasattr(self, 'size'): | ||
self.size = None | ||
|
||
def test_assert_is_bbox(self): | ||
if self.valid: | ||
assert_is_point(self.point, self.mask, self.size) | ||
else: | ||
with self.assertRaises(AssertionError): | ||
assert_is_point(self.point, self.mask, self.size) | ||
|
||
|
||
testing.run_module(__name__, __file__) |