Skip to content

Commit

Permalink
Removed debug statement and made the test python 3 compatible.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Nagel committed Dec 4, 2015
1 parent 86aab00 commit 4f77868
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
2 changes: 1 addition & 1 deletion fuel/transformers/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def __init__(self, data_stream, window_shape, location, **kwargs):
raise ValueError('Location must be a tuple or list of length 2 '
'(given {}).'.format(location))
if location[0] < 0 or location[0] > 1 or location[1] < 0 or \
location[1] > 1:
location[1] > 1:
raise ValueError('Location height and width must be between 0 '
'and 1 (given {}).'.format(location))
self.location = location
Expand Down
12 changes: 4 additions & 8 deletions tests/transformers/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
from PIL import Image
from picklable_itertools.extras import partition_all
from six.moves import zip

import pyximport
pyximport.install()
from fuel import config
from fuel.datasets.base import IndexableDataset
from fuel.schemes import ShuffledScheme, SequentialExampleScheme
Expand Down Expand Up @@ -322,13 +319,13 @@ def test_ndarray_batch_source(self):
for loc in [(0, 0), (0, 1), (1, 0), (1, 1)]:
stream = FixedSizeCrop(self.batch_stream, (5, 4),
which_sources=('source1',), location=loc)
batch = stream.get_epoch_iterator().next()
assert batch[0].shape[1:] == (3, 5, 4)
assert batch[0].shape[0] in (1, 2)
# seen indices should only be of that length in after last location
if 3 * 7 * 5 == len(seen_indices):
assert False
seen_indices = numpy.union1d(seen_indices, batch[0].flatten())
for batch in stream.get_epoch_iterator():
assert batch[0].shape[1:] == (3, 5, 4)
assert batch[0].shape[0] in (1, 2)
seen_indices = numpy.union1d(seen_indices, batch[0].flatten())
assert 3 * 7 * 5 == len(seen_indices)

def test_list_batch_source(self):
Expand Down Expand Up @@ -360,7 +357,6 @@ def test_objectarray_batch_source(self):
assert False
for batch in stream.get_epoch_iterator():
for example in batch[2]:
print example.shape
assert example.shape == (2, 5, 4)
seen_indices = numpy.union1d(seen_indices,
example.flatten())
Expand Down

0 comments on commit 4f77868

Please sign in to comment.