From 4d5e0610ac1430bacb9c364816dd084b825c888a Mon Sep 17 00:00:00 2001 From: Thomas Benson Date: Mon, 22 Jan 2024 08:36:16 -0800 Subject: [PATCH 1/3] Add abs2() operator for squared abs() Add abs2() operator that is equivalent to abs()*abs() without the extra loads and the unnecessary sqrt operators associated with complex abs() calculations. --- docs_input/api/math/misc/abs2.rst | 20 +++++++ include/matx/operators/scalar_ops.h | 14 +++++ include/matx/operators/unary_operators.h | 10 ++++ test/00_operators/OperatorTests.cu | 67 ++++++++++++++++++++++++ 4 files changed, 111 insertions(+) create mode 100644 docs_input/api/math/misc/abs2.rst diff --git a/docs_input/api/math/misc/abs2.rst b/docs_input/api/math/misc/abs2.rst new file mode 100644 index 000000000..c28ef4291 --- /dev/null +++ b/docs_input/api/math/misc/abs2.rst @@ -0,0 +1,20 @@ +.. _abss_func: + +abs2 +==== + +Squared absolute value. For complex numbers, this is the squared +complex magnitude, or real(t)^2 + imag(t)^2. For real numbers, +this is equivalent to the squared value, or t*t. + +.. doxygenfunction:: abs2(Op t) + +Examples +~~~~~~~~ + +.. literalinclude:: ../../../../test/00_operators/OperatorTests.cu + :language: cpp + :start-after: example-begin abs2-test-1 + :end-before: example-end abs2-test-1 + :dedent: + diff --git a/include/matx/operators/scalar_ops.h b/include/matx/operators/scalar_ops.h index 056d23bbc..0031f4fc5 100644 --- a/include/matx/operators/scalar_ops.h +++ b/include/matx/operators/scalar_ops.h @@ -285,6 +285,20 @@ template struct ExpjF { }; template using ExpjOp = UnOp>; +template struct Abs2F { + static __MATX_INLINE__ std::string str() { return "abs2"; } + + static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto op(T v) + { + if constexpr (is_complex_v) { + return v.real() * v.real() + v.imag() * v.imag(); + } + else { + return v * v; + } + } +}; +template using Abs2Op = UnOp>; template static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto _internal_normcdf(T v1) { diff --git a/include/matx/operators/unary_operators.h b/include/matx/operators/unary_operators.h index bdd6847f7..a426efe27 100644 --- a/include/matx/operators/unary_operators.h +++ b/include/matx/operators/unary_operators.h @@ -190,6 +190,15 @@ namespace matx */ Op abs(Op t) {} + /** + * Compute squared absolute value of every element in the tensor. For complex numbers + * this returns the squared magnitude, or real(t)^2 + imag(t)^2. For real numbers + * this returns the squared value, or t*t. + * @param t + * Tensor or operator input + */ + Op abs2(Op t) {} + /** * Compute the sine of every element in the tensor * @param t @@ -379,6 +388,7 @@ namespace matx #endif DEFINE_UNARY_OP(norm, detail::NormOp); DEFINE_UNARY_OP(abs, detail::AbsOp); + DEFINE_UNARY_OP(abs2, detail::Abs2Op); DEFINE_UNARY_OP(sin, detail::SinOp); DEFINE_UNARY_OP(cos, detail::CosOp); DEFINE_UNARY_OP(tan, detail::TanOp); diff --git a/test/00_operators/OperatorTests.cu b/test/00_operators/OperatorTests.cu index d55fd3102..6447699d6 100644 --- a/test/00_operators/OperatorTests.cu +++ b/test/00_operators/OperatorTests.cu @@ -1582,6 +1582,73 @@ TYPED_TEST(OperatorTestsAllExecs, OperatorFuncs) MATX_EXIT_HANDLER(); } +TYPED_TEST(OperatorTestsNumericAllExecs, Abs2) +{ + MATX_ENTER_HANDLER(); + using TestType = std::tuple_element_t<0, TypeParam>; + using ExecType = std::tuple_element_t<1, TypeParam>; + using inner_type = typename inner_op_type_t::type; + + ExecType exec{}; + + auto sync = [&exec]() constexpr { + if constexpr (std::is_same_v) { + cudaDeviceSynchronize(); + } + }; + + if constexpr (std::is_same_v> && + std::is_same_v) { + // example-begin abs2-test-1 + auto x = make_tensor>({}); + auto y = make_tensor({}); + x() = { 1.5f, 2.5f }; + (y = abs2(x)).run(); + cudaDeviceSynchronize(); + ASSERT_NEAR(y(), 1.5f*1.5f+2.5f*2.5f, 1.0e-6); + // example-end abs2-test-1 + } + + auto x = make_tensor({}); + auto y = make_tensor({}); + if constexpr (is_complex_v) { + x() = TestType{2.0, 2.0}; + (y = abs2(x)).run(exec); + sync(); + ASSERT_NEAR(y(), 8.0, 1.0e-6); + } else { + x() = 2.0; + (y = abs2(x)).run(exec); + sync(); + ASSERT_NEAR(y(), 4.0, 1.0e-6); + + // Test with higher rank tensor + auto x3 = make_tensor({3,3,3}); + auto y3 = make_tensor({3,3,3}); + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + for (int k = 0; k < 3; k++) { + x3(i,j,k) = static_cast(i*9 + j*3 + k); + } + } + } + + (y3 = abs2(x3)).run(exec); + sync(); + + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + for (int k = 0; k < 3; k++) { + TestType v = static_cast(i*9 + j*3 + k); + ASSERT_NEAR(y3(i,j,k), v*v, 1.0e-6); + } + } + } + } + + MATX_EXIT_HANDLER(); +} + TYPED_TEST(OperatorTestsFloatNonComplexAllExecs, OperatorFuncsR2C) { MATX_ENTER_HANDLER(); From b95de224ef1ad83fb8a30c8ad9d4250d116537cc Mon Sep 17 00:00:00 2001 From: Thomas Benson Date: Mon, 22 Jan 2024 08:41:15 -0800 Subject: [PATCH 2/3] Fix typo in abs2() documentation --- docs_input/api/math/misc/abs2.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs_input/api/math/misc/abs2.rst b/docs_input/api/math/misc/abs2.rst index c28ef4291..94cdb4ab3 100644 --- a/docs_input/api/math/misc/abs2.rst +++ b/docs_input/api/math/misc/abs2.rst @@ -1,4 +1,4 @@ -.. _abss_func: +.. _abs2_func: abs2 ==== From 97ec86247a9d88de7f7907473e89e9b8ade25d58 Mon Sep 17 00:00:00 2001 From: Thomas Benson Date: Mon, 22 Jan 2024 09:43:24 -0800 Subject: [PATCH 3/3] Use superscript instead of ^2 in abs2 rst doc --- docs_input/api/math/misc/abs2.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs_input/api/math/misc/abs2.rst b/docs_input/api/math/misc/abs2.rst index 94cdb4ab3..8e8736cf6 100644 --- a/docs_input/api/math/misc/abs2.rst +++ b/docs_input/api/math/misc/abs2.rst @@ -4,8 +4,8 @@ abs2 ==== Squared absolute value. For complex numbers, this is the squared -complex magnitude, or real(t)^2 + imag(t)^2. For real numbers, -this is equivalent to the squared value, or t*t. +complex magnitude, or real(t)\ :sup:`2` + imag(t)\ :sup:`2`. For real numbers, +this is equivalent to the squared value, or t\ :sup:`2`. .. doxygenfunction:: abs2(Op t)