Skip to content

Commit 18a23c2

Browse files
authored
Merge pull request #4929 from martin-frbg/issue4905
Fix CBLAS_?GEMMT filling in the wrong triangle for Row-Major
2 parents 5a79446 + 7ba6591 commit 18a23c2

File tree

5 files changed: 76 additions and 3 deletions.

interface/gemmt.c

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -319,8 +319,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
319319
lda = LDB;
320320
ldb = LDA;
321321

322-
if (Uplo == CblasUpper) uplo = 0;
323-
if (Uplo == CblasLower) uplo = 1;
322+
if (Uplo == CblasUpper) uplo = 1;
323+
if (Uplo == CblasLower) uplo = 0;
324324

325325
if (TransB == CblasNoTrans)
326326
transa = 0;

utest/test_extensions/test_cgemmt.c

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,28 @@ static void cgemmt_trusted(char api, enum CBLAS_ORDER order, char uplo, char tra
8181

8282
ldc *= 2;
8383

84+
#ifndef NO_CBLAS
85+
if (order == CblasRowMajor) {
86+
if (uplo == 'U' || uplo == CblasUpper)
87+
{
88+
for (i = 0; i < m; i++)
89+
for (j = i * 2; j < m * 2; j+=2){
90+
data_cgemmt.c_verify[i * ldc + j] =
91+
data_cgemmt.c_gemm[i * ldc + j];
92+
data_cgemmt.c_verify[i * ldc + j + 1] =
93+
data_cgemmt.c_gemm[i * ldc + j + 1];
94+
}
95+
} else {
96+
for (i = 0; i < m; i++)
97+
for (j = 0; j <= i * 2; j+=2){
98+
data_cgemmt.c_verify[i * ldc + j] =
99+
data_cgemmt.c_gemm[i * ldc + j];
100+
data_cgemmt.c_verify[i * ldc + j + 1] =
101+
data_cgemmt.c_gemm[i * ldc + j + 1];
102+
}
103+
}
104+
} else
105+
#endif
84106
if (uplo == 'L' || uplo == CblasLower)
85107
{
86108
for (i = 0; i < m; i++)

utest/test_extensions/test_dgemmt.c

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,21 @@ static void dgemmt_trusted(char api, enum CBLAS_ORDER order, char uplo, char tra
7777
else
7878
cblas_dgemm(order, transa, transb, m, m, k, alpha, data_dgemmt.a_test, lda,
7979
data_dgemmt.b_test, ldb, beta, data_dgemmt.c_gemm, ldc);
80+
81+
if (order == CblasRowMajor) {
82+
if (uplo == 'U' || uplo == CblasUpper)
83+
{
84+
for (i = 0; i < m; i++)
85+
for (j = i; j < m; j++)
86+
data_dgemmt.c_verify[i * ldc + j] =
87+
data_dgemmt.c_gemm[i * ldc + j];
88+
} else {
89+
for (i = 0; i < m; i++)
90+
for (j = 0; j <= i; j++)
91+
data_dgemmt.c_verify[i * ldc + j] =
92+
data_dgemmt.c_gemm[i * ldc + j];
93+
}
94+
}else
8095
#endif
8196

8297
if (uplo == 'L' || uplo == CblasLower)

utest/test_extensions/test_sgemmt.c

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,21 @@ static void sgemmt_trusted(char api, enum CBLAS_ORDER order, char uplo, char tra
7777
else
7878
cblas_sgemm(order, transa, transb, m, m, k, alpha, data_sgemmt.a_test, lda,
7979
data_sgemmt.b_test, ldb, beta, data_sgemmt.c_gemm, ldc);
80+
if (order == CblasRowMajor) {
81+
if (uplo == 'U' || uplo == CblasUpper)
82+
{
83+
for (i = 0; i < m; i++)
84+
for (j = i; j < m; j++)
85+
data_sgemmt.c_verify[i * ldc + j] =
86+
data_sgemmt.c_gemm[i * ldc + j];
87+
} else {
88+
for (i = 0; i < m; i++)
89+
for (j = 0; j <= i; j++)
90+
data_sgemmt.c_verify[i * ldc + j] =
91+
data_sgemmt.c_gemm[i * ldc + j];
92+
}
93+
94+
} else
8095
#endif
8196

8297
if (uplo == 'L' || uplo == CblasLower)

utest/test_extensions/test_zgemmt.c

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,28 @@ static void zgemmt_trusted(char api, enum CBLAS_ORDER order, char uplo, char tra
8080
#endif
8181

8282
ldc *= 2;
83-
83+
#ifndef NO_CBLAS
84+
if (order == CblasRowMajor) {
85+
if (uplo == 'U' || uplo == CblasUpper)
86+
{
87+
for (i = 0; i < m; i++)
88+
for (j = i * 2; j < m * 2; j+=2){
89+
data_zgemmt.c_verify[i * ldc + j] =
90+
data_zgemmt.c_gemm[i * ldc + j];
91+
data_zgemmt.c_verify[i * ldc + j + 1] =
92+
data_zgemmt.c_gemm[i * ldc + j + 1];
93+
}
94+
} else {
95+
for (i = 0; i < m; i++)
96+
for (j = 0; j <= i * 2; j+=2){
97+
data_zgemmt.c_verify[i * ldc + j] =
98+
data_zgemmt.c_gemm[i * ldc + j];
99+
data_zgemmt.c_verify[i * ldc + j + 1] =
100+
data_zgemmt.c_gemm[i * ldc + j + 1];
101+
}
102+
}
103+
}else
104+
#endif
84105
if (uplo == 'L' || uplo == CblasLower)
85106
{
86107
for (i = 0; i < m; i++)

0 commit comments

Comments
 (0)