Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
assert_is_point_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyu2172 committed Feb 27, 2018
1 parent 227275b commit 62a80ff
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 0 deletions.
1 change: 1 addition & 0 deletions chainercv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from chainercv.utils.testing import assert_is_image # NOQA
from chainercv.utils.testing import assert_is_label_dataset # NOQA
from chainercv.utils.testing import assert_is_point # NOQA
from chainercv.utils.testing import assert_is_point_dataset # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing import ConstantStubLink # NOQA
Expand Down
1 change: 1 addition & 0 deletions chainercv/utils/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from chainercv.utils.testing.assertions import assert_is_image # NOQA
from chainercv.utils.testing.assertions import assert_is_label_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_point # NOQA
from chainercv.utils.testing.assertions import assert_is_point_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing.constant_stub_link import ConstantStubLink # NOQA
Expand Down
1 change: 1 addition & 0 deletions chainercv/utils/testing/assertions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image # NOQA
from chainercv.utils.testing.assertions.assert_is_label_dataset import assert_is_label_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_point import assert_is_point # NOQA
from chainercv.utils.testing.assertions.assert_is_point_dataset import assert_is_point_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_dataset import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_link import assert_is_semantic_segmentation_link # NOQA
59 changes: 59 additions & 0 deletions chainercv/utils/testing/assertions/assert_is_point_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import numpy as np
import six

from chainercv.utils.testing.assertions.assert_is_image import assert_is_image
from chainercv.utils.testing.assertions.assert_is_point import assert_is_point


def assert_is_point_dataset(dataset, n_point=None, n_example=None,
no_mask=False):
"""Checks if a dataset satisfies the point dataset API.
This function checks if a given dataset satisfies the point dataset
API or not.
If the dataset does not satifiy the API, this function raises an
:class:`AssertionError`.
Args:
dataset: A dataset to be checked.
n_point (int): The number of expected points per image.
If thsi is :obj:`None`, the number of points per image can be
arbitrary.
n_example (int): The number of examples to be checked.
If this argument is specified, this function picks
examples ramdomly and checks them. Otherwise,
this function checks all examples.
no_mask (bool): If :obj:`True`, we assume that
point mask is always not contained.
If :obj:`False`, point mask may or may not be contained.
"""

assert len(dataset) > 0, 'The length of dataset must be greater than zero.'

if n_example:
for _ in six.moves.range(n_example):
i = np.random.randint(0, len(dataset))
_check_example(dataset[i], n_point, no_mask)
else:
for i in six.moves.range(len(dataset)):
_check_example(dataset[i], n_point, no_mask)


def _check_example(example, n_point=None, no_mask=False):
assert len(example) >= 2, \
'Each example must have at least two elements:' \
'img, point (mask is optional).'

if len(example) == 2 or no_mask:
img, point = example[:2]
mask = None
elif len(example) >= 3:
img, point, mask = example[:3]

assert_is_image(img, color=True)
assert_is_point(point, mask, img.shape[1:])

if n_point is not None:
assert point.shape[0] == n_point, \
'The number of points is different from the expected number.'
4 changes: 4 additions & 0 deletions docs/source/reference/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ assert_is_point
~~~~~~~~~~~~~~~
.. autofunction:: assert_is_point

assert_is_point_dataset
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: assert_is_point_dataset

assert_is_semantic_segmentation_dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: assert_is_semantic_segmentation_dataset
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import numpy as np
import unittest

from chainer.dataset import DatasetMixin
from chainer import testing

from chainercv.utils import assert_is_point_dataset


class PointDataset(DatasetMixin):

H = 48
W = 64

def __init__(self, n_point_candidates,
return_mask, *options):
self.n_point_candidates = n_point_candidates
self.return_mask = return_mask
self.options = options

def __len__(self):
return 10

def get_example(self, i):
img = np.random.randint(0, 256, size=(3, self.H, self.W))
n_point = np.random.choice(self.n_point_candidates)
point_y = np.random.uniform(0, self.H, size=(n_point,))
point_x = np.random.uniform(0, self.W, size=(n_point,))
point = np.stack((point_y, point_x), axis=1).astype(np.float32)
if self.return_mask:
mask = np.random.randint(0, 2, size=(n_point,)).astype(np.bool)
return (img, point, mask) + self.options
else:
return (img, point) + self.options


class InvalidSampleSizeDataset(PointDataset):

def get_example(self, i):
img = super(
InvalidSampleSizeDataset, self).get_example(i)[0]
return img


class InvalidImageDataset(PointDataset):

def get_example(self, i):
img = super(
InvalidImageDataset, self).get_example(i)[0]
rest = super(
InvalidImageDataset, self).get_example(i)[1:]
return (img[0],) + rest


class InvalidPointDataset(PointDataset):

def get_example(self, i):
img, point = super(InvalidPointDataset, self).get_example(i)[:2]
rest = super(InvalidPointDataset, self).get_example(i)[2:]
point += 1000
return (img, point) + rest


@testing.parameterize(
# No optional Values
{'dataset': PointDataset([10, 15], True), 'valid': True, 'n_point': None},
{'dataset': PointDataset([10, 15], False), 'valid': True, 'n_point': None},
{'dataset': PointDataset([15], True), 'valid': True, 'n_point': 15},
{'dataset': PointDataset([15], False), 'valid': True, 'n_point': 15},
# Invalid n_point
{'dataset': PointDataset([15], True), 'valid': False, 'n_point': 10},
{'dataset': PointDataset([15], False), 'valid': False, 'n_point': 10},
# Return optional values
{'dataset': PointDataset([10, 15], True, 'option'),
'valid': True, 'n_point': None},
{'dataset': PointDataset([10, 15], False, 'option'),
'valid': True, 'n_point': None, 'no_mask': True},
{'dataset': PointDataset([15], True, 'option'),
'valid': True, 'n_point': 15},
{'dataset': PointDataset([15], False, 'option'),
'valid': True, 'n_point': 15, 'no_mask': True},
# Invalid datasets
{'dataset': InvalidSampleSizeDataset([10], True),
'valid': False, 'n_point': None},
{'dataset': InvalidImageDataset([10], True),
'valid': False, 'n_point': None},
{'dataset': InvalidPointDataset([10], True),
'valid': False, 'n_point': None},
)
class TestAssertIsPointDataset(unittest.TestCase):

def setUp(self):
if not hasattr(self, 'no_mask'):
self.no_mask = False

def test_assert_is_point_dataset(self):
if self.valid:
assert_is_point_dataset(
self.dataset, self.n_point, 20, self.no_mask)
else:
with self.assertRaises(AssertionError):
assert_is_point_dataset(
self.dataset, self.n_point, 20, self.no_mask)


testing.run_module(__name__, __file__)

0 comments on commit 62a80ff

Please sign in to comment.