@@ -659,6 +659,66 @@ TYPED_TEST(MatMulTestFloatNonHalfTypes, MatMulOp)
659
659
MATX_EXIT_HANDLER ();
660
660
}
661
661
662
+ TYPED_TEST (MatMulTestFloatNonHalfTypes, MatMulBroadcast)
663
+ {
664
+ MATX_ENTER_HANDLER ();
665
+
666
+ constexpr index_t n = 16 ;
667
+ constexpr index_t b = 8 ;
668
+ constexpr index_t x = 3 ;
669
+ constexpr index_t y = 4 ;
670
+
671
+ tensor_t <TypeParam, 2 > eye2{{n, n}};
672
+ tensor_t <TypeParam, 5 > a5{{x, y, b, n, n}};
673
+ tensor_t <TypeParam, 5 > c5{{x, y, b, n, n}};
674
+
675
+ const TypeParam two { 2.0 };
676
+ const TypeParam three { 3.0 };
677
+
678
+ (eye2 = two*eye<TypeParam>({n,n})).run ();
679
+ (a5 = three).run ();
680
+
681
+ (c5 = 0 ).run ();
682
+ // Broadcast eye2, scaling each entry in a5 by 2
683
+ (c5 = matmul (eye2, a5)).run ();
684
+
685
+ cudaDeviceSynchronize ();
686
+
687
+ for (index_t i0 = 0 ; i0 < x; i0++)
688
+ for (index_t i1 = 0 ; i1 < y; i1++)
689
+ for (index_t i2 = 0 ; i2 < b; i2++)
690
+ for (index_t i3 = 0 ; i3 < n; i3++)
691
+ for (index_t i4 = 0 ; i4 < n; i4++) {
692
+ if constexpr (is_complex_v<TypeParam>) {
693
+ ASSERT_NEAR (c5 (i0,i1,i2,i3,i4).real (), 2.0 *a5 (i0,i1,i2,i3,i4).real (), this ->thresh );
694
+ ASSERT_NEAR (c5 (i0,i1,i2,i3,i4).imag (), 2.0 *a5 (i0,i1,i2,i3,i4).imag (), this ->thresh );
695
+ } else {
696
+ ASSERT_NEAR (c5 (i0,i1,i2,i3,i4), two*a5 (i0,i1,i2,i3,i4), this ->thresh );
697
+ }
698
+ }
699
+
700
+ (c5 = 0 ).run ();
701
+ // Broadcast eye2, scaling each entry in a5 by 2
702
+ (c5 = matmul (a5, eye2)).run ();
703
+
704
+ cudaDeviceSynchronize ();
705
+
706
+ for (index_t i0 = 0 ; i0 < x; i0++)
707
+ for (index_t i1 = 0 ; i1 < y; i1++)
708
+ for (index_t i2 = 0 ; i2 < b; i2++)
709
+ for (index_t i3 = 0 ; i3 < n; i3++)
710
+ for (index_t i4 = 0 ; i4 < n; i4++) {
711
+ if constexpr (is_complex_v<TypeParam>) {
712
+ ASSERT_NEAR (c5 (i0,i1,i2,i3,i4).real (), 2.0 *a5 (i0,i1,i2,i3,i4).real (), this ->thresh );
713
+ ASSERT_NEAR (c5 (i0,i1,i2,i3,i4).imag (), 2.0 *a5 (i0,i1,i2,i3,i4).imag (), this ->thresh );
714
+ } else {
715
+ ASSERT_NEAR (c5 (i0,i1,i2,i3,i4), two*a5 (i0,i1,i2,i3,i4), this ->thresh );
716
+ }
717
+ }
718
+
719
+ MATX_EXIT_HANDLER ();
720
+ }
721
+
662
722
TYPED_TEST (MatMulTestFloatTypes, MediumMatVec)
663
723
{
664
724
MATX_ENTER_HANDLER ();
0 commit comments