This repository has been archived by the owner on Jul 2, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 303
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
172 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
chainercv/utils/testing/assertions/assert_is_point_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import numpy as np | ||
import six | ||
|
||
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image | ||
from chainercv.utils.testing.assertions.assert_is_point import assert_is_point | ||
|
||
|
||
def assert_is_point_dataset(dataset, n_point=None, n_example=None, | ||
no_mask=False): | ||
"""Checks if a dataset satisfies the point dataset API. | ||
This function checks if a given dataset satisfies the point dataset | ||
API or not. | ||
If the dataset does not satifiy the API, this function raises an | ||
:class:`AssertionError`. | ||
Args: | ||
dataset: A dataset to be checked. | ||
n_point (int): The number of expected points per image. | ||
If thsi is :obj:`None`, the number of points per image can be | ||
arbitrary. | ||
n_example (int): The number of examples to be checked. | ||
If this argument is specified, this function picks | ||
examples ramdomly and checks them. Otherwise, | ||
this function checks all examples. | ||
no_mask (bool): If :obj:`True`, we assume that | ||
point mask is always not contained. | ||
If :obj:`False`, point mask may or may not be contained. | ||
""" | ||
|
||
assert len(dataset) > 0, 'The length of dataset must be greater than zero.' | ||
|
||
if n_example: | ||
for _ in six.moves.range(n_example): | ||
i = np.random.randint(0, len(dataset)) | ||
_check_example(dataset[i], n_point, no_mask) | ||
else: | ||
for i in six.moves.range(len(dataset)): | ||
_check_example(dataset[i], n_point, no_mask) | ||
|
||
|
||
def _check_example(example, n_point=None, no_mask=False): | ||
assert len(example) >= 2, \ | ||
'Each example must have at least two elements:' \ | ||
'img, point (mask is optional).' | ||
|
||
if len(example) == 2 or no_mask: | ||
img, point = example[:2] | ||
mask = None | ||
elif len(example) >= 3: | ||
img, point, mask = example[:3] | ||
|
||
assert_is_image(img, color=True) | ||
assert_is_point(point, mask, img.shape[1:]) | ||
|
||
if n_point is not None: | ||
assert point.shape[0] == n_point, \ | ||
'The number of points is different from the expected number.' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 106 additions & 0 deletions
106
tests/utils_tests/testing_tests/assertions_tests/test_assert_is_point_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import numpy as np | ||
import unittest | ||
|
||
from chainer.dataset import DatasetMixin | ||
from chainer import testing | ||
|
||
from chainercv.utils import assert_is_point_dataset | ||
|
||
|
||
class PointDataset(DatasetMixin): | ||
|
||
H = 48 | ||
W = 64 | ||
|
||
def __init__(self, n_point_candidates, | ||
return_mask, *options): | ||
self.n_point_candidates = n_point_candidates | ||
self.return_mask = return_mask | ||
self.options = options | ||
|
||
def __len__(self): | ||
return 10 | ||
|
||
def get_example(self, i): | ||
img = np.random.randint(0, 256, size=(3, self.H, self.W)) | ||
n_point = np.random.choice(self.n_point_candidates) | ||
point_y = np.random.uniform(0, self.H, size=(n_point,)) | ||
point_x = np.random.uniform(0, self.W, size=(n_point,)) | ||
point = np.stack((point_y, point_x), axis=1).astype(np.float32) | ||
if self.return_mask: | ||
mask = np.random.randint(0, 2, size=(n_point,)).astype(np.bool) | ||
return (img, point, mask) + self.options | ||
else: | ||
return (img, point) + self.options | ||
|
||
|
||
class InvalidSampleSizeDataset(PointDataset): | ||
|
||
def get_example(self, i): | ||
img = super( | ||
InvalidSampleSizeDataset, self).get_example(i)[0] | ||
return img | ||
|
||
|
||
class InvalidImageDataset(PointDataset): | ||
|
||
def get_example(self, i): | ||
img = super( | ||
InvalidImageDataset, self).get_example(i)[0] | ||
rest = super( | ||
InvalidImageDataset, self).get_example(i)[1:] | ||
return (img[0],) + rest | ||
|
||
|
||
class InvalidPointDataset(PointDataset): | ||
|
||
def get_example(self, i): | ||
img, point = super(InvalidPointDataset, self).get_example(i)[:2] | ||
rest = super(InvalidPointDataset, self).get_example(i)[2:] | ||
point += 1000 | ||
return (img, point) + rest | ||
|
||
|
||
@testing.parameterize( | ||
# No optional Values | ||
{'dataset': PointDataset([10, 15], True), 'valid': True, 'n_point': None}, | ||
{'dataset': PointDataset([10, 15], False), 'valid': True, 'n_point': None}, | ||
{'dataset': PointDataset([15], True), 'valid': True, 'n_point': 15}, | ||
{'dataset': PointDataset([15], False), 'valid': True, 'n_point': 15}, | ||
# Invalid n_point | ||
{'dataset': PointDataset([15], True), 'valid': False, 'n_point': 10}, | ||
{'dataset': PointDataset([15], False), 'valid': False, 'n_point': 10}, | ||
# Return optional values | ||
{'dataset': PointDataset([10, 15], True, 'option'), | ||
'valid': True, 'n_point': None}, | ||
{'dataset': PointDataset([10, 15], False, 'option'), | ||
'valid': True, 'n_point': None, 'no_mask': True}, | ||
{'dataset': PointDataset([15], True, 'option'), | ||
'valid': True, 'n_point': 15}, | ||
{'dataset': PointDataset([15], False, 'option'), | ||
'valid': True, 'n_point': 15, 'no_mask': True}, | ||
# Invalid datasets | ||
{'dataset': InvalidSampleSizeDataset([10], True), | ||
'valid': False, 'n_point': None}, | ||
{'dataset': InvalidImageDataset([10], True), | ||
'valid': False, 'n_point': None}, | ||
{'dataset': InvalidPointDataset([10], True), | ||
'valid': False, 'n_point': None}, | ||
) | ||
class TestAssertIsPointDataset(unittest.TestCase): | ||
|
||
def setUp(self): | ||
if not hasattr(self, 'no_mask'): | ||
self.no_mask = False | ||
|
||
def test_assert_is_point_dataset(self): | ||
if self.valid: | ||
assert_is_point_dataset( | ||
self.dataset, self.n_point, 20, self.no_mask) | ||
else: | ||
with self.assertRaises(AssertionError): | ||
assert_is_point_dataset( | ||
self.dataset, self.n_point, 20, self.no_mask) | ||
|
||
|
||
testing.run_module(__name__, __file__) |