Skip to content

Commit 44a655b

Browse files
authored
Add details when TestCase.run_layer_test output verification fails. (#21165)
Adds the expected/actual output shapes/dtypes in the failure message. Also greatly simplifies the code by using `keras.tree`.
1 parent 6d52164 commit 44a655b

File tree

1 file changed

+24
-104
lines changed

1 file changed

+24
-104
lines changed

keras/src/testing/test_case.py

+24-104
Original file line numberDiff line numberDiff line change
@@ -355,111 +355,31 @@ def run_build_asserts(layer):
355355

356356
def run_output_asserts(layer, output, eager=False):
357357
if expected_output_shape is not None:
358-
if isinstance(expected_output_shape, tuple) and is_shape_tuple(
359-
expected_output_shape[0]
360-
):
361-
self.assertIsInstance(output, tuple)
362-
self.assertEqual(
363-
len(output),
364-
len(expected_output_shape),
365-
msg="Unexpected number of outputs",
366-
)
367-
output_shape = tuple(v.shape for v in output)
368-
self.assertEqual(
369-
expected_output_shape,
370-
output_shape,
371-
msg="Unexpected output shape",
372-
)
373-
elif isinstance(expected_output_shape, tuple):
374-
self.assertEqual(
375-
expected_output_shape,
376-
output.shape,
377-
msg="Unexpected output shape",
378-
)
379-
elif isinstance(expected_output_shape, dict):
380-
self.assertIsInstance(output, dict)
381-
self.assertEqual(
382-
set(output.keys()),
383-
set(expected_output_shape.keys()),
384-
msg="Unexpected output dict keys",
385-
)
386-
output_shape = {k: v.shape for k, v in output.items()}
387-
self.assertEqual(
388-
expected_output_shape,
389-
output_shape,
390-
msg="Unexpected output shape",
391-
)
392-
elif isinstance(expected_output_shape, list):
393-
self.assertIsInstance(output, list)
394-
self.assertEqual(
395-
len(output),
396-
len(expected_output_shape),
397-
msg="Unexpected number of outputs",
398-
)
399-
output_shape = [v.shape for v in output]
400-
self.assertEqual(
401-
expected_output_shape,
402-
output_shape,
403-
msg="Unexpected output shape",
404-
)
405-
else:
406-
raise ValueError(
407-
"The type of expected_output_shape is not supported"
408-
)
358+
359+
def verify_shape(expected_shape, x):
360+
return expected_shape == x.shape
361+
362+
shapes_match = tree.map_structure_up_to(
363+
output, verify_shape, expected_output_shape, output
364+
)
365+
self.assertTrue(
366+
all(tree.flatten(shapes_match)),
367+
msg=f"Expected output shapes {expected_output_shape} but "
368+
f"received {tree.map_structure(lambda x: x.shape, output)}",
369+
)
409370
if expected_output_dtype is not None:
410-
if isinstance(expected_output_dtype, tuple):
411-
self.assertIsInstance(output, tuple)
412-
self.assertEqual(
413-
len(output),
414-
len(expected_output_dtype),
415-
msg="Unexpected number of outputs",
416-
)
417-
output_dtype = tuple(
418-
backend.standardize_dtype(v.dtype) for v in output
419-
)
420-
self.assertEqual(
421-
expected_output_dtype,
422-
output_dtype,
423-
msg="Unexpected output dtype",
424-
)
425-
elif isinstance(expected_output_dtype, dict):
426-
self.assertIsInstance(output, dict)
427-
self.assertEqual(
428-
set(output.keys()),
429-
set(expected_output_dtype.keys()),
430-
msg="Unexpected output dict keys",
431-
)
432-
output_dtype = {
433-
k: backend.standardize_dtype(v.dtype)
434-
for k, v in output.items()
435-
}
436-
self.assertEqual(
437-
expected_output_dtype,
438-
output_dtype,
439-
msg="Unexpected output dtype",
440-
)
441-
elif isinstance(expected_output_dtype, list):
442-
self.assertIsInstance(output, list)
443-
self.assertEqual(
444-
len(output),
445-
len(expected_output_dtype),
446-
msg="Unexpected number of outputs",
447-
)
448-
output_dtype = [
449-
backend.standardize_dtype(v.dtype) for v in output
450-
]
451-
self.assertEqual(
452-
expected_output_dtype,
453-
output_dtype,
454-
msg="Unexpected output dtype",
455-
)
456-
else:
457-
output_dtype = tree.flatten(output)[0].dtype
458-
self.assertEqual(
459-
expected_output_dtype,
460-
backend.standardize_dtype(output_dtype),
461-
msg="Unexpected output dtype",
462-
)
371+
372+
def verify_dtype(expected_dtype, x):
373+
return expected_dtype == backend.standardize_dtype(x.dtype)
374+
375+
dtypes_match = tree.map_structure(
376+
verify_dtype, expected_output_dtype, output
377+
)
378+
self.assertTrue(
379+
all(tree.flatten(dtypes_match)),
380+
msg=f"Expected output dtypes {expected_output_dtype} but "
381+
f"received {tree.map_structure(lambda x: x.dtype, output)}",
382+
)
463383
if expected_output_sparse:
464384
for x in tree.flatten(output):
465385
self.assertSparse(x)

0 commit comments

Comments
 (0)