@@ -355,111 +355,31 @@ def run_build_asserts(layer):
355
355
356
356
def run_output_asserts (layer , output , eager = False ):
357
357
if expected_output_shape is not None :
358
- if isinstance (expected_output_shape , tuple ) and is_shape_tuple (
359
- expected_output_shape [0 ]
360
- ):
361
- self .assertIsInstance (output , tuple )
362
- self .assertEqual (
363
- len (output ),
364
- len (expected_output_shape ),
365
- msg = "Unexpected number of outputs" ,
366
- )
367
- output_shape = tuple (v .shape for v in output )
368
- self .assertEqual (
369
- expected_output_shape ,
370
- output_shape ,
371
- msg = "Unexpected output shape" ,
372
- )
373
- elif isinstance (expected_output_shape , tuple ):
374
- self .assertEqual (
375
- expected_output_shape ,
376
- output .shape ,
377
- msg = "Unexpected output shape" ,
378
- )
379
- elif isinstance (expected_output_shape , dict ):
380
- self .assertIsInstance (output , dict )
381
- self .assertEqual (
382
- set (output .keys ()),
383
- set (expected_output_shape .keys ()),
384
- msg = "Unexpected output dict keys" ,
385
- )
386
- output_shape = {k : v .shape for k , v in output .items ()}
387
- self .assertEqual (
388
- expected_output_shape ,
389
- output_shape ,
390
- msg = "Unexpected output shape" ,
391
- )
392
- elif isinstance (expected_output_shape , list ):
393
- self .assertIsInstance (output , list )
394
- self .assertEqual (
395
- len (output ),
396
- len (expected_output_shape ),
397
- msg = "Unexpected number of outputs" ,
398
- )
399
- output_shape = [v .shape for v in output ]
400
- self .assertEqual (
401
- expected_output_shape ,
402
- output_shape ,
403
- msg = "Unexpected output shape" ,
404
- )
405
- else :
406
- raise ValueError (
407
- "The type of expected_output_shape is not supported"
408
- )
358
+
359
+ def verify_shape (expected_shape , x ):
360
+ return expected_shape == x .shape
361
+
362
+ shapes_match = tree .map_structure_up_to (
363
+ output , verify_shape , expected_output_shape , output
364
+ )
365
+ self .assertTrue (
366
+ all (tree .flatten (shapes_match )),
367
+ msg = f"Expected output shapes { expected_output_shape } but "
368
+ f"received { tree .map_structure (lambda x : x .shape , output )} " ,
369
+ )
409
370
if expected_output_dtype is not None :
410
- if isinstance (expected_output_dtype , tuple ):
411
- self .assertIsInstance (output , tuple )
412
- self .assertEqual (
413
- len (output ),
414
- len (expected_output_dtype ),
415
- msg = "Unexpected number of outputs" ,
416
- )
417
- output_dtype = tuple (
418
- backend .standardize_dtype (v .dtype ) for v in output
419
- )
420
- self .assertEqual (
421
- expected_output_dtype ,
422
- output_dtype ,
423
- msg = "Unexpected output dtype" ,
424
- )
425
- elif isinstance (expected_output_dtype , dict ):
426
- self .assertIsInstance (output , dict )
427
- self .assertEqual (
428
- set (output .keys ()),
429
- set (expected_output_dtype .keys ()),
430
- msg = "Unexpected output dict keys" ,
431
- )
432
- output_dtype = {
433
- k : backend .standardize_dtype (v .dtype )
434
- for k , v in output .items ()
435
- }
436
- self .assertEqual (
437
- expected_output_dtype ,
438
- output_dtype ,
439
- msg = "Unexpected output dtype" ,
440
- )
441
- elif isinstance (expected_output_dtype , list ):
442
- self .assertIsInstance (output , list )
443
- self .assertEqual (
444
- len (output ),
445
- len (expected_output_dtype ),
446
- msg = "Unexpected number of outputs" ,
447
- )
448
- output_dtype = [
449
- backend .standardize_dtype (v .dtype ) for v in output
450
- ]
451
- self .assertEqual (
452
- expected_output_dtype ,
453
- output_dtype ,
454
- msg = "Unexpected output dtype" ,
455
- )
456
- else :
457
- output_dtype = tree .flatten (output )[0 ].dtype
458
- self .assertEqual (
459
- expected_output_dtype ,
460
- backend .standardize_dtype (output_dtype ),
461
- msg = "Unexpected output dtype" ,
462
- )
371
+
372
+ def verify_dtype (expected_dtype , x ):
373
+ return expected_dtype == backend .standardize_dtype (x .dtype )
374
+
375
+ dtypes_match = tree .map_structure (
376
+ verify_dtype , expected_output_dtype , output
377
+ )
378
+ self .assertTrue (
379
+ all (tree .flatten (dtypes_match )),
380
+ msg = f"Expected output dtypes { expected_output_dtype } but "
381
+ f"received { tree .map_structure (lambda x : x .dtype , output )} " ,
382
+ )
463
383
if expected_output_sparse :
464
384
for x in tree .flatten (output ):
465
385
self .assertSparse (x )
0 commit comments