Skip to content

Commit

Permalink
Check all output dimensions for compatibility (keras-team#4420)
Browse files Browse the repository at this point in the history
  • Loading branch information
kilotaras authored and fchollet committed Nov 19, 2016
1 parent 04ea01f commit 6b04add
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
19 changes: 10 additions & 9 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,12 @@ def check_array_lengths(X, Y, W):


def check_loss_and_target_compatibility(targets, losses, output_shapes):
assert len(targets) == len(losses) == len(output_shapes)
key_losses = {'mean_square_error',
'binary_crossentropy',
'categorical_crossentropy'}
for y, loss, shape in zip(targets, losses, output_shapes):
if loss.__name__ == 'categorical_crossentropy':
if y.shape[1] == 1:
if y.shape[-1] == 1:
raise Exception('You are passing a target array of shape ' + str(y.shape) +
' while using as loss `categorical_crossentropy`. '
'`categorical_crossentropy` expects '
Expand All @@ -208,13 +207,15 @@ def check_loss_and_target_compatibility(targets, losses, output_shapes):
'Alternatively, you can use the loss function '
'`sparse_categorical_crossentropy` instead, '
'which does expect integer targets.')
if loss.__name__ in key_losses and shape[1] is not None and y.shape[1] != shape[1]:
raise Exception('A target array with shape ' + str(y.shape) +
' was passed for an output of shape ' + str(shape) +
' while using as loss `' + loss.__name__ + '`. '
'This loss expects '
'targets to have the same shape '
'as the output.')
if loss.__name__ in key_losses:
for target_dim, out_dim in zip(y.shape[1:], shape[1:]):
if target_dim is not None and target_dim != out_dim:
raise Exception('A target array with shape ' + str(y.shape) +
' was passed for an output of shape ' + str(shape) +
' while using as loss `' + loss.__name__ + '`. '
'This loss expects '
'targets to have the same shape '
'as the output.')


def collect_metrics(metrics, output_names):
Expand Down
26 changes: 25 additions & 1 deletion tests/keras/engine/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from keras.layers import Dense, Dropout
from keras.engine.topology import merge, Input
from keras.engine.training import Model
from keras.engine.training import Model, check_loss_and_target_compatibility
from keras.models import Sequential
from keras import backend as K
from keras.utils.test_utils import keras_test
Expand Down Expand Up @@ -202,5 +202,29 @@ def test_trainable_argument():
assert_allclose(out, out_2)


@keras_test
def test_check_not_last_is_one():
a = np.random.random((2, 1, 3))
check_loss_and_target_compatibility([a], [K.categorical_crossentropy], [a.shape])


@keras_test
def test_check_last_is_one():
a = np.random.random((2, 3, 1))
with pytest.raises(Exception) as exc:
check_loss_and_target_compatibility([a], [K.categorical_crossentropy], [a.shape])

assert "You are passing a target array" in str(exc)


@keras_test
def test_check_bad_shape():
a = np.random.random((2, 3, 5))
with pytest.raises(Exception) as exc:
check_loss_and_target_compatibility([a], [K.categorical_crossentropy], [(2, 3, 6)])

assert "targets to have the same shape" in str(exc)


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit 6b04add

Please sign in to comment.