Fixed size crops #272

Open · wants to merge 4 commits into base: master
133 changes: 121 additions & 12 deletions fuel/transformers/image.py
@@ -223,20 +223,17 @@ def transform_source_batch(self, source, source_name):
if isinstance(source, numpy.ndarray) and source.ndim == 4:
# Hardcoded assumption of (batch, channels, height, width).
# This is what the fast Cython code supports.
out = numpy.empty(source.shape[:2] + self.window_shape,
dtype=source.dtype)
batch_size = source.shape[0]
image_height, image_width = source.shape[2:]
max_h_off = image_height - windowed_height
max_w_off = image_width - windowed_width
if max_h_off < 0 or max_w_off < 0:
raise ValueError("Got ndarray batch with image dimensions {} "
"but requested window shape of {}".format(
source.shape[2:], self.window_shape))
offsets_w = self.rng.random_integers(0, max_w_off, size=batch_size)
offsets_h = self.rng.random_integers(0, max_h_off, size=batch_size)
window_batch_bchw(source, offsets_h, offsets_w, out)
return out
return self._crop_batch(source, offsets_h, offsets_w)
elif all(isinstance(b, numpy.ndarray) and b.ndim == 3 for b in source):
return [self.transform_source_example(im, source_name)
for im in source]
@@ -249,16 +246,15 @@ def transform_source_example(self, example, source_name):
self.verify_axis_labels(('channel', 'height', 'width'),
self.data_stream.axis_labels[source_name],
source_name)

windowed_height, windowed_width = self.window_shape
if not isinstance(example, numpy.ndarray) or example.ndim != 3:
raise ValueError("uninterpretable example format; expected "
"ndarray with ndim = 3")
image_height, image_width = example.shape[1:]
if image_height < windowed_height or image_width < windowed_width:
raise ValueError("can't obtain ({}, {}) window from image "
"dimensions ({}, {})".format(
windowed_height, windowed_width,
image_height, image_width))

if image_height - windowed_height > 0:
off_h = self.rng.random_integers(0, image_height - windowed_height)
else:
@@ -267,8 +263,121 @@ def transform_source_example(self, example, source_name):
off_w = self.rng.random_integers(0, image_width - windowed_width)
else:
off_w = 0
return example[:, off_h:off_h + windowed_height,
off_w:off_w + windowed_width]
return self._crop_example(example, off_h, off_w)

def _crop_batch(self, source, offsets_h, offsets_w):
out = numpy.empty(source.shape[:2] + self.window_shape,
dtype=source.dtype)
window_batch_bchw(source, offsets_h, offsets_w, out)
return out

def _crop_example(self, example, offset_h, offset_w):
if not isinstance(example, numpy.ndarray) or example.ndim != 3:
raise ValueError("uninterpretable example format; expected "
"ndarray with ndim = 3")
windowed_height, windowed_width = self.window_shape
image_height, image_width = example.shape[1:]
if image_height < offset_h + windowed_height or \
image_width < offset_w + windowed_width:
raise ValueError("can't obtain ({}, {}) from image "
"dimensions ({}, {}) with offset ({}, {})".format(
windowed_height, windowed_width,
image_height, image_width,
offset_h + windowed_height, offset_w + windowed_width))

return example[:,
offset_h:offset_h + windowed_height,
offset_w:offset_w + windowed_width]


class FixedSizeCrop(RandomFixedSizeCrop):
"""Crop images to a fixed window size at a specific position.

Parameters
----------
data_stream : :class:`AbstractDataStream`
The data stream to wrap.
window_shape : tuple
The `(height, width)` tuple representing the size of the output
window.
location : tuple
Location of the crop `(height, width)` relative to the image size,
with each coordinate between 0 and 1: (0, 0) is the top left corner,
(1, 1) the lower right corner and (.5, .5) the center.

Notes
-----
This transformer expects to act on stream sources which provide one of

* Single images represented as 3-dimensional ndarrays, with layout
`(channel, height, width)`.
* Batches of images represented as lists of 3-dimensional ndarrays,
possibly of different shapes (i.e. images of differing
heights/widths).
* Batches of images represented as 4-dimensional ndarrays, with
layout `(batch, channel, height, width)`.

The format of the stream will be unaltered, i.e. if lists are
yielded by `data_stream` then lists will be yielded by this
transformer.
Reviewer comment (Contributor): There should also be a blank line before the closing set of quotes.
"""
def __init__(self, data_stream, window_shape, location, **kwargs):
if not isinstance(location, (list, tuple)) or len(location) != 2:
raise ValueError('Location must be a tuple or list of length 2 '
'(given {}).'.format(location))
if location[0] < 0 or location[0] > 1 or \
location[1] < 0 or location[1] > 1:
raise ValueError('Location height and width must be between 0 '
'and 1 (given {}).'.format(location))
self.location = location
super(FixedSizeCrop, self).__init__(data_stream, window_shape,
**kwargs)

def transform_source_batch(self, source, source_name):
self.verify_axis_labels(('batch', 'channel', 'height', 'width'),
self.data_stream.axis_labels[source_name],
source_name)
if isinstance(source, list) and all(isinstance(b, numpy.ndarray) and
b.ndim == 3 for b in source):
return [self.transform_source_example(im, source_name)
for im in source]
elif isinstance(source, numpy.ndarray) and \
source.dtype == numpy.object:
return numpy.array([self.transform_source_example(im,
source_name)
for im in source])
elif isinstance(source, numpy.ndarray) and source.ndim == 4:
windowed_height, windowed_width = self.window_shape
image_height, image_width = source.shape[2:]
loc_height, loc_width = self.location
off_h = int(round((image_height - windowed_height) * loc_height))
off_w = int(round((image_width - windowed_width) * loc_width))
if image_height < off_h + windowed_height or \
image_width < off_w + windowed_width:
raise ValueError("can't obtain ({}, {}) window from image "
"dimensions ({}, {}) at location ({}, {})"
.format(windowed_height, windowed_width,
image_height, image_width,
*self.location))
return source[:, :,
off_h:off_h + windowed_height,
off_w:off_w + windowed_width]
else:
raise ValueError("uninterpretable batch format; expected a list "
"of arrays with ndim = 3, or an array with "
"ndim = 4")

def transform_source_example(self, example, source_name):
self.verify_axis_labels(('channel', 'height', 'width'),
self.data_stream.axis_labels[source_name],
source_name)
windowed_height, windowed_width = self.window_shape
loc_height, loc_width = self.location
image_height, image_width = example.shape[1:]
off_h = int(round((image_height - windowed_height) * loc_height))
off_w = int(round((image_width - windowed_width) * loc_width))
return self._crop_example(example, off_h, off_w)


class Random2DRotation(SourcewiseTransformer, ExpectsAxisLabels):
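A usage sketch (not part of this pull request): the snippet below shows how the new FixedSizeCrop might be wrapped around a batch stream to take a centre crop. The dataset and stream setup mirrors the conventions in the test suite below; the 'images' source name, shapes and batch size are hypothetical.

from collections import OrderedDict

import numpy
from fuel.datasets import IndexableDataset
from fuel.schemes import SequentialScheme
from fuel.streams import DataStream
from fuel.transformers.image import FixedSizeCrop

# Nine identical 3-channel 7x6 images whose pixel values encode their own
# indices, so it is easy to see which region a crop came from.
image = numpy.arange(3 * 7 * 6, dtype='uint8').reshape(3, 7, 6)
images = numpy.tile(image, (9, 1, 1, 1))

dataset = IndexableDataset(
    OrderedDict([('images', images)]),
    axis_labels={'images': ('batch', 'channel', 'height', 'width')})
stream = DataStream(dataset,
                    iteration_scheme=SequentialScheme(9, batch_size=3))

# location=(.5, .5) asks for the centre: with a 7x6 image and a (5, 4)
# window the offsets come out as round((7 - 5) * .5) = 1 and
# round((6 - 4) * .5) = 1.
cropped = FixedSizeCrop(stream, (5, 4),
                        which_sources=('images',), location=(.5, .5))
for batch, in cropped.get_epoch_iterator():
    assert batch.shape == (3, 3, 5, 4)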
111 changes: 111 additions & 0 deletions tests/transformers/test_image.py
@@ -12,6 +12,7 @@
from fuel.transformers.image import (ImagesFromBytes,
MinimumImageDimensions,
RandomFixedSizeCrop,
FixedSizeCrop,
Random2DRotation)


@@ -435,3 +436,113 @@ def test_random_2D_rotation_batch_stream(self):
out = bstream.transform_source_batch(self.source3, 'source3')
assert_equal(out[0], expected[0])
assert_equal(out[1], expected[1])


class TestFixedSizeCrop(ImageTestingMixin):
def setUp(self):
source1 = numpy.zeros((9, 3, 7, 5), dtype='uint8')
source1[:] = numpy.arange(3 * 7 * 5, dtype='uint8').reshape(3, 7, 5)
shapes = [(5, 8), (6, 8), (5, 6), (5, 5), (6, 4), (7, 4),
(9, 4), (5, 6), (6, 5)]
source2 = []
biggest = 0
num_channels = 2
for shp in shapes:
biggest = max(biggest, shp[0] * shp[1] * 2)
ex = numpy.arange(shp[0] * shp[1] * num_channels).reshape(
(num_channels,) + shp).astype('uint8')
source2.append(ex)
self.source2_biggest = biggest
source3 = numpy.empty((len(shapes),), dtype=object)
for i in range(len(source2)):
source3[i] = source2[i]
axis_labels = {'source1': ('batch', 'channel', 'height', 'width'),
'source2': ('batch', 'channel', 'height', 'width'),
'source3': ('batch', 'channel', 'height', 'width')}
self.dataset = IndexableDataset(OrderedDict([('source1', source1),
('source2', source2),
('source3', source3)]),
axis_labels=axis_labels)
self.common_setup()

def test_ndarray_batch_source(self):
# Make sure that with 4 corner crops we sample everything.
seen_indices = numpy.array([], dtype='uint8')
for loc in [(0, 0), (0, 1), (1, 0), (1, 1)]:
stream = FixedSizeCrop(self.batch_stream, (5, 4),
which_sources=('source1',), location=loc)
# seen_indices should only reach this length after the last location
if 3 * 7 * 5 == len(seen_indices):
assert False
for batch in stream.get_epoch_iterator():
assert batch[0].shape[1:] == (3, 5, 4)
assert batch[0].shape[0] in (1, 2)
seen_indices = numpy.union1d(seen_indices, batch[0].flatten())
assert 3 * 7 * 5 == len(seen_indices)

def test_list_batch_source(self):
# Make sure that with 4 corner crops we sample everything.
seen_indices = numpy.array([], dtype='uint8')

for loc in [(0, 0), (0, 1), (1, 0), (1, 1)]:
stream = FixedSizeCrop(self.batch_stream, (5, 4),
which_sources=('source2',), location=loc)
# seen_indices should only reach this length after the last location
if self.source2_biggest == len(seen_indices):
assert False
for batch in stream.get_epoch_iterator():
for example in batch[1]:
assert example.shape == (2, 5, 4)
seen_indices = numpy.union1d(seen_indices,
example.flatten())
assert self.source2_biggest == len(seen_indices)

def test_objectarray_batch_source(self):
# Make sure that with 4 corner crops we sample everything.
seen_indices = numpy.array([], dtype='uint8')

for loc in [(0, 0), (0, 1), (1, 0), (1, 1)]:
stream = FixedSizeCrop(self.batch_stream, (5, 4),
which_sources=('source3',), location=loc)
# seen_indices should only reach this length after the last location
if self.source2_biggest == len(seen_indices):
assert False
for batch in stream.get_epoch_iterator():
for example in batch[2]:
assert example.shape == (2, 5, 4)
seen_indices = numpy.union1d(seen_indices,
example.flatten())
assert self.source2_biggest == len(seen_indices)

def test_wrong_location_exceptions(self):
assert_raises(ValueError, FixedSizeCrop, self.example_stream, (5, 4),
which_sources=('source2',), location=1)
assert_raises(ValueError, FixedSizeCrop, self.example_stream, (5, 4),
which_sources=('source2',), location=[0, 1, 0])
assert_raises(ValueError, FixedSizeCrop, self.example_stream, (5, 4),
which_sources=('source2',), location=[2, 0])

def test_format_exceptions(self):
estream = FixedSizeCrop(self.example_stream, (5, 4),
which_sources=('source2',), location=[0, 0])
bstream = FixedSizeCrop(self.batch_stream, (5, 4),
which_sources=('source2',), location=[0, 0])
assert_raises(ValueError, estream.transform_source_example,
numpy.empty((5, 6)), 'source2')
assert_raises(ValueError, bstream.transform_source_batch,
[numpy.empty((7, 6))], 'source2')
assert_raises(ValueError, bstream.transform_source_batch,
[numpy.empty((8, 6))], 'source2')

def test_window_too_big_exceptions(self):
stream = FixedSizeCrop(self.example_stream, (5, 4),
which_sources=('source2',), location=[0, 0])

assert_raises(ValueError, stream.transform_source_example,
numpy.empty((3, 4, 2)), 'source2')

bstream = FixedSizeCrop(self.batch_stream, (5, 4),
which_sources=('source1',), location=[0, 0])

assert_raises(ValueError, bstream.transform_source_batch,
numpy.empty((5, 3, 4, 2)), 'source1')
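
A note on the corner-crop coverage argument used in these tests (illustration only, not part of the PR): with a (5, 4) window on a 7x5 image, the four corner locations between them visit every pixel, which is why the union of cropped values must reach 3 * 7 * 5 above. A minimal single-channel sketch of the same offset arithmetic:

import numpy

# One 7x5 image with uniquely numbered pixels, mirroring source1 above
# (single channel for brevity).
image = numpy.arange(7 * 5, dtype='uint8').reshape(7, 5)

seen = numpy.array([], dtype='uint8')
for loc_h, loc_w in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    # Same rule as FixedSizeCrop: scale the slack by the requested location.
    off_h = int(round((7 - 5) * loc_h))
    off_w = int(round((5 - 4) * loc_w))
    crop = image[off_h:off_h + 5, off_w:off_w + 4]
    seen = numpy.union1d(seen, crop.flatten())

# The four corner crops together touch all 35 pixel positions.
assert len(seen) == 7 * 5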