diff --git a/docs_input/api/math/misc/abs2.rst b/docs_input/api/math/misc/abs2.rst new file mode 100644 index 000000000..8e8736cf6 --- /dev/null +++ b/docs_input/api/math/misc/abs2.rst @@ -0,0 +1,20 @@ +.. _abs2_func: + +abs2 +==== + +Squared absolute value. For complex numbers, this is the squared +complex magnitude, or real(t)\ :sup:`2` + imag(t)\ :sup:`2`. For real numbers, +this is equivalent to the squared value, or t\ :sup:`2`. + +.. doxygenfunction:: abs2(Op t) + +Examples +~~~~~~~~ + +.. literalinclude:: ../../../../test/00_operators/OperatorTests.cu + :language: cpp + :start-after: example-begin abs2-test-1 + :end-before: example-end abs2-test-1 + :dedent: + diff --git a/include/matx/operators/scalar_ops.h b/include/matx/operators/scalar_ops.h index 056d23bbc..0031f4fc5 100644 --- a/include/matx/operators/scalar_ops.h +++ b/include/matx/operators/scalar_ops.h @@ -285,6 +285,20 @@ template struct ExpjF { }; template using ExpjOp = UnOp>; +template struct Abs2F { + static __MATX_INLINE__ std::string str() { return "abs2"; } + + static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto op(T v) + { + if constexpr (is_complex_v) { + return v.real() * v.real() + v.imag() * v.imag(); + } + else { + return v * v; + } + } +}; +template using Abs2Op = UnOp>; template static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto _internal_normcdf(T v1) { diff --git a/include/matx/operators/unary_operators.h b/include/matx/operators/unary_operators.h index bdd6847f7..a426efe27 100644 --- a/include/matx/operators/unary_operators.h +++ b/include/matx/operators/unary_operators.h @@ -190,6 +190,15 @@ namespace matx */ Op abs(Op t) {} + /** + * Compute squared absolute value of every element in the tensor. For complex numbers + * this returns the squared magnitude, or real(t)^2 + imag(t)^2. For real numbers + * this returns the squared value, or t*t. + * @param t + * Tensor or operator input + */ + Op abs2(Op t) {} + /** * Compute the sine of every element in the tensor * @param t @@ -379,6 +388,7 @@ namespace matx #endif DEFINE_UNARY_OP(norm, detail::NormOp); DEFINE_UNARY_OP(abs, detail::AbsOp); + DEFINE_UNARY_OP(abs2, detail::Abs2Op); DEFINE_UNARY_OP(sin, detail::SinOp); DEFINE_UNARY_OP(cos, detail::CosOp); DEFINE_UNARY_OP(tan, detail::TanOp); diff --git a/test/00_operators/OperatorTests.cu b/test/00_operators/OperatorTests.cu index d55fd3102..6447699d6 100644 --- a/test/00_operators/OperatorTests.cu +++ b/test/00_operators/OperatorTests.cu @@ -1582,6 +1582,73 @@ TYPED_TEST(OperatorTestsAllExecs, OperatorFuncs) MATX_EXIT_HANDLER(); } +TYPED_TEST(OperatorTestsNumericAllExecs, Abs2) +{ + MATX_ENTER_HANDLER(); + using TestType = std::tuple_element_t<0, TypeParam>; + using ExecType = std::tuple_element_t<1, TypeParam>; + using inner_type = typename inner_op_type_t::type; + + ExecType exec{}; + + auto sync = [&exec]() constexpr { + if constexpr (std::is_same_v) { + cudaDeviceSynchronize(); + } + }; + + if constexpr (std::is_same_v> && + std::is_same_v) { + // example-begin abs2-test-1 + auto x = make_tensor>({}); + auto y = make_tensor({}); + x() = { 1.5f, 2.5f }; + (y = abs2(x)).run(); + cudaDeviceSynchronize(); + ASSERT_NEAR(y(), 1.5f*1.5f+2.5f*2.5f, 1.0e-6); + // example-end abs2-test-1 + } + + auto x = make_tensor({}); + auto y = make_tensor({}); + if constexpr (is_complex_v) { + x() = TestType{2.0, 2.0}; + (y = abs2(x)).run(exec); + sync(); + ASSERT_NEAR(y(), 8.0, 1.0e-6); + } else { + x() = 2.0; + (y = abs2(x)).run(exec); + sync(); + ASSERT_NEAR(y(), 4.0, 1.0e-6); + + // Test with higher rank tensor + auto x3 = make_tensor({3,3,3}); + auto y3 = make_tensor({3,3,3}); + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + for (int k = 0; k < 3; k++) { + x3(i,j,k) = static_cast(i*9 + j*3 + k); + } + } + } + + (y3 = abs2(x3)).run(exec); + sync(); + + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + for (int k = 0; k < 3; k++) { + TestType v = static_cast(i*9 + j*3 + k); + ASSERT_NEAR(y3(i,j,k), v*v, 1.0e-6); + } + } + } + } + + MATX_EXIT_HANDLER(); +} + TYPED_TEST(OperatorTestsFloatNonComplexAllExecs, OperatorFuncsR2C) { MATX_ENTER_HANDLER();