From 2360abec8dd7d1b297e167c5e6188fd41753a92e Mon Sep 17 00:00:00 2001 From: beomki-yeo Date: Tue, 16 Apr 2024 17:26:23 +0200 Subject: [PATCH] Optimize the cmath operators --- .../algebra/math/impl/cmath_matrix.hpp | 20 ++++----- .../algebra/math/impl/cmath_operators.hpp | 45 ++++++++++--------- tests/common/test_host_basics.hpp | 16 +++++++ 3 files changed, 49 insertions(+), 32 deletions(-) diff --git a/math/cmath/include/algebra/math/impl/cmath_matrix.hpp b/math/cmath/include/algebra/math/impl/cmath_matrix.hpp index 23e3a116..c1cb3b42 100644 --- a/math/cmath/include/algebra/math/impl/cmath_matrix.hpp +++ b/math/cmath/include/algebra/math/impl/cmath_matrix.hpp @@ -69,8 +69,8 @@ struct actor { ALGEBRA_HOST_DEVICE void set_block(input_matrix_type &m, const matrix_type &b, size_type row, size_type col) const { - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { element_getter()(m, i + row, j + col) = element_getter()(b, i, j); } } @@ -92,8 +92,8 @@ struct actor { ALGEBRA_HOST_DEVICE inline matrix_type zero() const { matrix_type ret; - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { element_getter()(ret, i, j) = 0; } } @@ -106,8 +106,8 @@ struct actor { ALGEBRA_HOST_DEVICE inline matrix_type identity() const { matrix_type ret; - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { if (i == j) { element_getter()(ret, i, j) = 1; } else { @@ -123,8 +123,8 @@ struct actor { template ALGEBRA_HOST_DEVICE inline void set_zero(matrix_type &m) const { - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { element_getter()(m, i, j) = 0; } } @@ -135,8 +135,8 @@ struct actor { ALGEBRA_HOST_DEVICE inline void set_identity( matrix_type &m) const { - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { if (i == j) { element_getter()(m, i, j) = 1; } else { diff --git a/math/cmath/include/algebra/math/impl/cmath_operators.hpp b/math/cmath/include/algebra/math/impl/cmath_operators.hpp index 0c56df91..a38581dc 100644 --- a/math/cmath/include/algebra/math/impl/cmath_operators.hpp +++ b/math/cmath/include/algebra/math/impl/cmath_operators.hpp @@ -137,8 +137,8 @@ ALGEBRA_HOST_DEVICE inline array_t, COLS> operator*( array_t, COLS> ret; - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { ret[j][i] = a[j][i] * static_cast(s); } } @@ -154,8 +154,8 @@ ALGEBRA_HOST_DEVICE inline array_t, COLS> operator*( array_t, COLS> ret; - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { ret[j][i] = a[j][i] * static_cast(s); } } @@ -171,8 +171,8 @@ ALGEBRA_HOST_DEVICE inline array_t, COLS> operator*( array_t, COLS> ret; - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { ret[j][i] = a[j][i] * static_cast(s); } } @@ -188,8 +188,8 @@ ALGEBRA_HOST_DEVICE inline array_t, COLS> operator*( array_t, COLS> ret; - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { ret[j][i] = a[j][i] * static_cast(s); } } @@ -206,16 +206,17 @@ ALGEBRA_HOST_DEVICE inline array_t, O> operator*( array_t, O> C; - for (size_type i = 0; i < M; ++i) { - for (size_type j = 0; j < O; ++j) { - - scalar_t val = 0; + for (size_type j = 0; j < O; ++j) { + for (size_type i = 0; i < M; ++i) { + C[j][i] = 0.f; + } + } - for (size_type k = 0; k < N; ++k) { - val += A[k][i] * B[j][k]; + for (size_type i = 0; i < N; ++i) { + for (size_type j = 0; j < O; ++j) { + for (size_type k = 0; k < M; ++k) { + C[j][k] += A[i][k] * B[j][i]; } - - C[j][i] = val; } } @@ -231,8 +232,8 @@ ALGEBRA_HOST_DEVICE inline array_t, COLS> operator+( array_t, COLS> C; - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { C[j][i] = A[j][i] + B[j][i]; } } @@ -249,8 +250,8 @@ ALGEBRA_HOST_DEVICE inline array_t, COLS> operator-( array_t, COLS> C; - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { C[j][i] = A[j][i] - B[j][i]; } } @@ -272,8 +273,8 @@ ALGEBRA_HOST_DEVICE inline array_t operator*( array_t ret{0}; - for (size_type i = 0; i < ROWS; ++i) { - for (size_type j = 0; j < COLS; ++j) { + for (size_type j = 0; j < COLS; ++j) { + for (size_type i = 0; i < ROWS; ++i) { ret[i] += a[j][i] * b[j]; } } diff --git a/tests/common/test_host_basics.hpp b/tests/common/test_host_basics.hpp index 4d05fc0c..f0ca6252 100644 --- a/tests/common/test_host_basics.hpp +++ b/tests/common/test_host_basics.hpp @@ -226,6 +226,22 @@ TYPED_TEST_P(test_host_basics, matrix64) { ASSERT_NEAR(algebra::getter::element(m, 0, 2), 10., this->m_epsilon); ASSERT_NEAR(algebra::getter::element(m, 1, 2), 20., this->m_epsilon); ASSERT_NEAR(algebra::getter::element(m, 2, 2), 30., this->m_epsilon); + + typename TypeParam::template matrix<3, 3> m33; + algebra::getter::element(m33, 0, 0) = 1; + algebra::getter::element(m33, 1, 0) = 2; + algebra::getter::element(m33, 2, 0) = 3; + algebra::getter::element(m33, 0, 1) = 5; + algebra::getter::element(m33, 1, 1) = 6; + algebra::getter::element(m33, 2, 1) = 7; + algebra::getter::element(m33, 0, 2) = 9; + algebra::getter::element(m33, 1, 2) = 10; + algebra::getter::element(m33, 2, 2) = 11; + + const typename TypeParam::vector3 v2 = m33 * v; + ASSERT_NEAR(v2[0], 380., this->m_epsilon); + ASSERT_NEAR(v2[1], 440., this->m_epsilon); + ASSERT_NEAR(v2[2], 500., this->m_epsilon); } // Test matrix operations with 3x3 matrix