diff --git a/src/type_traits/complex_help.hpp b/src/type_traits/complex_help.hpp index 79e0e920a4..6a309084a9 100644 --- a/src/type_traits/complex_help.hpp +++ b/src/type_traits/complex_help.hpp @@ -12,6 +12,9 @@ #ifndef QMCPLUSPLUS_COMPLEX_HELP_HPP #define QMCPLUSPLUS_COMPLEX_HELP_HPP +#include +#include + namespace qmcplusplus { template @@ -26,14 +29,21 @@ using IsComplex = std::enable_if_t::value, bool>; template using IsReal = std::enable_if_t::value, bool>; -template -struct RealAlias_impl {}; +template +struct RealAlias_impl +{}; -template -struct RealAlias_impl> { using value_type = T; }; +template +struct RealAlias_impl> +{ + using value_type = T; +}; -template -struct RealAlias_impl> { using value_type = typename T::value_type; }; +template +struct RealAlias_impl> +{ + using value_type = typename T::value_type; +}; /** If you have a function templated on a value that can be real or complex * and you need to get the base Real type if its complex or just the real. @@ -41,9 +51,32 @@ struct RealAlias_impl> { using value_type = typename T::value_ty * If you try to do this on anything but a fp or a std::complex you will * get a compilation error. */ -template +template using RealAlias = typename RealAlias_impl::value_type; +template +struct ValueAlias_impl +{}; + +template +struct ValueAlias_impl> +{ + using value_type = TREAL; +}; + +template +struct ValueAlias_impl> +{ + using value_type = std::complex; +}; + +/** If you need to make a value type of a given precision based on a reference value type + * set the desired POD float point type as TREAL and set the reference type as TREF. + * If TREF is real/complex, the generated Value type is real/complex. + */ +template::value>> +using ValueAlias = typename ValueAlias_impl::value_type; + ///real part of a scalar. Cannot be replaced by std::real due to AFQMC specific needs. inline float real(const float& c) { return c; } inline double real(const double& c) { return c; } @@ -59,7 +92,7 @@ inline float conj(const float& c) { return c; } inline double conj(const double& c) { return c; } inline std::complex conj(const std::complex& c) { return std::conj(c); } inline std::complex conj(const std::complex& c) { return std::conj(c); } - + } // namespace qmcplusplus #endif diff --git a/src/type_traits/tests/CMakeLists.txt b/src/type_traits/tests/CMakeLists.txt index 5761a14c90..08bf12935c 100644 --- a/src/type_traits/tests/CMakeLists.txt +++ b/src/type_traits/tests/CMakeLists.txt @@ -14,7 +14,7 @@ set(SRC_DIR type_traits) set(UTEST_EXE test_${SRC_DIR}) set(UTEST_NAME deterministic-unit_test_${SRC_DIR}) -set(TEST_SRCS test_qmctypes.cpp test_template_types.cpp) +set(TEST_SRCS test_qmctypes.cpp test_template_types.cpp test_complex_helper.cpp) add_executable(${UTEST_EXE} ${TEST_SRCS}) target_link_libraries(${UTEST_EXE} catch_main containers) diff --git a/src/type_traits/tests/test_complex_helper.cpp b/src/type_traits/tests/test_complex_helper.cpp new file mode 100644 index 0000000000..b93b0717fa --- /dev/null +++ b/src/type_traits/tests/test_complex_helper.cpp @@ -0,0 +1,46 @@ +////////////////////////////////////////////////////////////////////////////////////// +// This file is distributed under the University of Illinois/NCSA Open Source License. +// See LICENSE file in top directory for details. +// +// Copyright (c) 2023 QMCPACK developers +// +// File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory +// +// File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory +////////////////////////////////////////////////////////////////////////////////////// + +#include "catch.hpp" +#include "type_traits/complex_help.hpp" + +namespace qmcplusplus +{ +template +class TestComplexHelper +{ + using Cmplx = std::complex

; + using Real = RealAlias; + using CmplxRebuild = ValueAlias; + using RealRebuild = ValueAlias; + +public: + void run() + { + Cmplx aa; + CmplxRebuild bb; + aa = bb; + + Real cc; + RealRebuild dd(0); + cc = dd; + } +}; + +TEST_CASE("complex_helper", "[type_traits]") +{ + TestComplexHelper float_test; + float_test.run(); + TestComplexHelper double_test; + double_test.run(); +} + +} // namespace qmcplusplus