Skip to content

Commit d75d702

Browse files
authored
Broadcast lower-rank tensors during batched matmul (#585)
When performing a matmul on two tensors with mismatched ranks, where at least one rank is greater than 3, broadcast the lower-rank tensor. This also fixes a bug in the batched cov transform. Signed-off-by: Thomas Benson <[email protected]>
1 parent 52109df commit d75d702

File tree

3 files changed

+72
-6
lines changed

3 files changed

+72
-6
lines changed

include/matx/transforms/matmul.h

+9-5
Original file line number · Diff line number · Diff line change
@@ -808,7 +808,9 @@ class matxMatMulHandle_t {
808808

809809
// Prep for batch looping
810810
using shape_type = typename TensorTypeA::desc_type::shape_type;
811-
[[maybe_unused]] std::array<shape_type, TensorTypeA::Rank()> idx{0};
811+
[[maybe_unused]] std::array<shape_type, TensorTypeA::Rank()> a_idx{0};
812+
[[maybe_unused]] std::array<shape_type, TensorTypeB::Rank()> b_idx{0};
813+
[[maybe_unused]] std::array<shape_type, TensorTypeC::Rank()> c_idx{0};
812814
[[maybe_unused]] auto a_shape = a.Shape();
813815
[[maybe_unused]] size_t total_iter = 1;
814816

@@ -855,9 +857,9 @@ class matxMatMulHandle_t {
855857
for (size_t iter = 0; iter < total_iter; iter++) {
856858

857859
// Get pointers into A/B/C for this round
858-
auto ap = std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, idx);
859-
auto bp = std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, idx);
860-
auto cp = std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, idx);
860+
auto ap = std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, a_idx);
861+
auto bp = std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, b_idx);
862+
auto cp = std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, c_idx);
861863
auto res = cublasLtMatmul(
862864
ltHandle, operationDesc, &salpha, (void *)ap,
863865
Adesc, (void *)bp, Bdesc, &sbeta,
@@ -868,7 +870,9 @@ class matxMatMulHandle_t {
868870
MATX_ASSERT(res == CUBLAS_STATUS_SUCCESS, matxMatMulError);
869871

870872
// Update all but the last 3 indices
871-
UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(a_adj, idx, 3);
873+
UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(a_adj, a_idx, 3);
874+
UpdateIndices<TensorTypeB, shape_type, TensorTypeB::Rank()>(b_adj, b_idx, 3);
875+
UpdateIndices<TensorTypeC, shape_type, TensorTypeC::Rank()>(c_adj, c_idx, 3);
872876
}
873877
}
874878
}

test/00_transform/Cov.cu

+3-1
Original file line number · Diff line number · Diff line change
@@ -101,10 +101,12 @@ TYPED_TEST(CovarianceTestFloatTypes, BatchedCov)
101101

102102
(batched_out = cov(batched_in)).run();
103103

104+
cudaDeviceSynchronize();
105+
104106
for (int im = 0; im < m; im++) {
105107
for (int in = 0; in < n; in++) {
106108
for (int ik = 0; ik < k; ik++) {
107-
auto bv = slice<2>(batched_out, {im,in,ik,0,0}, {matxDropDim,matxDropDim,matxDropDim,matxKeepDim,matxKeepDim});
109+
auto bv = slice<2>(batched_out, {im,in,ik,0,0}, {matxDropDim,matxDropDim,matxDropDim,matxEnd,matxEnd});
108110
MATX_TEST_ASSERT_COMPARE(this->pb, bv, "c_cov", this->thresh);
109111
}
110112
}

test/00_transform/MatMul.cu

+60
Original file line number · Diff line number · Diff line change
@@ -659,6 +659,66 @@ TYPED_TEST(MatMulTestFloatNonHalfTypes, MatMulOp)
659659
MATX_EXIT_HANDLER();
660660
}
661661

662+
TYPED_TEST(MatMulTestFloatNonHalfTypes, MatMulBroadcast)
663+
{
664+
MATX_ENTER_HANDLER();
665+
666+
constexpr index_t n = 16;
667+
constexpr index_t b = 8;
668+
constexpr index_t x = 3;
669+
constexpr index_t y = 4;
670+
671+
tensor_t<TypeParam, 2> eye2{{n, n}};
672+
tensor_t<TypeParam, 5> a5{{x, y, b, n, n}};
673+
tensor_t<TypeParam, 5> c5{{x, y, b, n, n}};
674+
675+
const TypeParam two { 2.0 };
676+
const TypeParam three { 3.0 };
677+
678+
(eye2 = two*eye<TypeParam>({n,n})).run();
679+
(a5 = three).run();
680+
681+
(c5 = 0).run();
682+
// Broadcast eye2, scaling each entry in a5 by 2
683+
(c5 = matmul(eye2, a5)).run();
684+
685+
cudaDeviceSynchronize();
686+
687+
for (index_t i0 = 0; i0 < x; i0++)
688+
for (index_t i1 = 0; i1 < y; i1++)
689+
for (index_t i2 = 0; i2 < b; i2++)
690+
for (index_t i3 = 0; i3 < n; i3++)
691+
for (index_t i4 = 0; i4 < n; i4++) {
692+
if constexpr (is_complex_v<TypeParam>) {
693+
ASSERT_NEAR(c5(i0,i1,i2,i3,i4).real(), 2.0*a5(i0,i1,i2,i3,i4).real(), this->thresh);
694+
ASSERT_NEAR(c5(i0,i1,i2,i3,i4).imag(), 2.0*a5(i0,i1,i2,i3,i4).imag(), this->thresh);
695+
} else {
696+
ASSERT_NEAR(c5(i0,i1,i2,i3,i4), two*a5(i0,i1,i2,i3,i4), this->thresh);
697+
}
698+
}
699+
700+
(c5 = 0).run();
701+
// Broadcast eye2, scaling each entry in a5 by 2
702+
(c5 = matmul(a5, eye2)).run();
703+
704+
cudaDeviceSynchronize();
705+
706+
for (index_t i0 = 0; i0 < x; i0++)
707+
for (index_t i1 = 0; i1 < y; i1++)
708+
for (index_t i2 = 0; i2 < b; i2++)
709+
for (index_t i3 = 0; i3 < n; i3++)
710+
for (index_t i4 = 0; i4 < n; i4++) {
711+
if constexpr (is_complex_v<TypeParam>) {
712+
ASSERT_NEAR(c5(i0,i1,i2,i3,i4).real(), 2.0*a5(i0,i1,i2,i3,i4).real(), this->thresh);
713+
ASSERT_NEAR(c5(i0,i1,i2,i3,i4).imag(), 2.0*a5(i0,i1,i2,i3,i4).imag(), this->thresh);
714+
} else {
715+
ASSERT_NEAR(c5(i0,i1,i2,i3,i4), two*a5(i0,i1,i2,i3,i4), this->thresh);
716+
}
717+
}
718+
719+
MATX_EXIT_HANDLER();
720+
}
721+
662722
TYPED_TEST(MatMulTestFloatTypes, MediumMatVec)
663723
{
664724
MATX_ENTER_HANDLER();

0 commit comments

Comments (0)