diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..5f2670a45 --- /dev/null +++ b/.clang-format @@ -0,0 +1,145 @@ +--- +Language: Cpp +AccessModifierOffset: -4 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignConsecutiveMacros: true +AlignEscapedNewlines: DontAlign +AlignOperands: true +AlignTrailingComments: false +AllowAllArgumentsOnNextLine: true +AllowAllConstructorInitializersOnNextLine: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: Empty +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterCaseLabel: false + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: true + BeforeElse: true + IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Custom +BreakBeforeTernaryOperators: true +BreakConstructorInitializers: BeforeColon +BreakInheritanceList: BeforeColon +BreakStringLiterals: false +ColumnLimit: 100 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 8 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: false +DerivePointerAlignment: true +DisableFormat: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Preserve 
+IncludeCategories: + - Regex: '^' + Priority: 2 + - Regex: '^<.*\.h>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IndentCaseLabels: true +IndentPPDirectives: None +IndentWidth: 4 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +RawStringFormats: + - Language: Cpp + Delimiters: + - cc + - CC + - cpp + - Cpp + - CPP + - 'c++' + - 'C++' + CanonicalDelimiter: '' + BasedOnStyle: google + - Language: TextProto + Delimiters: + - pb + - PB + - proto + - PROTO + EnclosingFunctions: + - EqualsProto + - EquivToProto + - PARSE_PARTIAL_TEXT_PROTO + - PARSE_TEST_PROTO + - PARSE_TEXT_PROTO + - ParseTextOrDie + - ParseTextProtoOrDie + CanonicalDelimiter: '' + BasedOnStyle: google +ReflowComments: false +SortIncludes: false +SortUsingDeclarations: false +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: false +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 1 +UseTab: Never +... 
diff --git a/CMakeLists.txt b/CMakeLists.txt index da4c9517e..e4dce0040 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,9 +34,6 @@ check_compiler_version() if (NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") endif() -if (NOT CMAKE_DISABLE_SYCL) - set(CMAKE_DISABLE_SYCL 0) -endif() #make build variable case insensitive string( TOLOWER "${CMAKE_BUILD_TYPE}" CMAKE_BUILD_TYPE_CASE_INSENSITIVE) @@ -47,6 +44,7 @@ if (${CMAKE_BUILD_TYPE_CASE_INSENSITIVE} STREQUAL "debug") set(USE_SECURITY_FLAGS FALSE) endif() +option(BUILD_UT "Build unit tests" TRUE) option(USE_CODECOV_FLAGS "Calculate code coverage" FALSE) option(WITH_ASAN "Use address sanitizer, can only be used in Debug build" FALSE) @@ -58,6 +56,7 @@ if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) endif() #show build info +message(STATUS "Build unit tests: ${BUILD_UT}") message(STATUS "Installation directory: ${CMAKE_INSTALL_PREFIX}") message(STATUS "Build type: ${CMAKE_BUILD_TYPE_CASE_INSENSITIVE}") message(STATUS "C compiler : ${CMAKE_C_COMPILER}") @@ -93,8 +92,6 @@ include_directories(${LIBFABRIC_INCLUDE_DIR}) link_directories(${MPI_LIB_DIR}) link_directories(${LIBFABRIC_LIB_DIR}) - - set(CCL_INSTALL_UNIT_TESTS "${CMAKE_INSTALL_PREFIX}/tests/unit") set(CMAKE_SKIP_INSTALL_RPATH TRUE) @@ -144,14 +141,40 @@ set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} ${CXX_COMP set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(TRY_ENABLE_SYCL_L0 OFF) + if (COMPUTE_RUNTIME) activate_compute_runtime("${CMAKE_CURRENT_LIST_DIR}/cmake" ${COMPUTE_RUNTIME}) - set(PARENT_COMPUTE_RUNTIME_TARGET_NAME ${COMPUTE_RUNTIME_TARGET_NAME}) + if (NOT COMPUTE_RUNTIME_TARGET_NAME) + message(FATAL_ERROR "Failed to find requested compute runtime: ${COMPUTE_RUNTIME}") + endif() + message(STATUS "COMPUTE_RUNTIME_TARGET_NAME: ${COMPUTE_RUNTIME_TARGET_NAME}") if (${CCL_ENABLE_SYCL_V} STREQUAL 1) option (CCL_ENABLE_SYCL "Enable CCL SYCL runtime" ON) message(STATUS "Enable CCL SYCL runtime") + if 
(${COMPUTE_RUNTIME_TARGET_NAME} STREQUAL "Intel::SYCL") + set (CCL_ENABLE_SYCL_CHECK_CONTRACT "#if defined(__cplusplus)\n#if !defined(__clang__) || __clang_major__ < 9 || !defined(SYCL_LANGUAGE_VERSION)\n#error This version of CCL configured only for oneAPI DPC++ Compiler\n#endif\n#endif") + execute_process(COMMAND dpcpp -v + OUTPUT_VARIABLE DPCPP_VERSION + ERROR_VARIABLE DPCPP_VERSION + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_STRIP_TRAILING_WHITESPACE + ) + message(STATUS "DPC++ compiler version:\n" "${DPCPP_VERSION}") + else(${COMPUTE_RUNTIME_TARGET_NAME} STREQUAL "Codeplay::ComputeCpp") + set (CCL_ENABLE_SYCL_CHECK_CONTRACT "#if defined(__cplusplus)\n#if !defined(__clang__) || __clang_major__ < 6\n#error This version of CCL configured only for oneAPI DPC++ Compiler\n#endif\n#endif") + endif() endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${COMPUTE_RUNTIME_FLAGS}") + if ((${COMPUTE_RUNTIME_TARGET_NAME} STREQUAL "Intel::SYCL") AND + ${CCL_ENABLE_SYCL_L0} STREQUAL 1) + set(MULTI_GPU_SUPPORT ON) + elseif(${COMPUTE_RUNTIME_TARGET_NAME} STREQUAL "ze_loader") + set(MULTI_GPU_SUPPORT ON) + endif() + if (MULTI_GPU_SUPPORT) + message(STATUS "Enable multi GPU support using L0") + endif() endif() if(${CMAKE_C_COMPILER_ID} STREQUAL "GNU" AND ${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") @@ -210,6 +233,9 @@ if (CCL_BF16_COMPILER) endif() endif() +add_definitions(-DCCL_GPU_BF16_TRUNCATE) +set(CCL_GPU_BF16_TRUNCATE ON) + set(CCL_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/src) enable_testing() @@ -222,32 +248,6 @@ set(CMAKE_CLANG_FLAGS "${CMAKE_CLANG_FLAGS}") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -if (COMPUTE_RUNTIME) - if ((${COMPUTE_RUNTIME_TARGET_NAME} STREQUAL "ze_loader") - OR (${COMPUTE_RUNTIME_TARGET_NAME} STREQUAL "Intel::SYCL")) - set(MULTI_GPU_SUPPORT ON) - activate_compute_runtime("${CMAKE_CURRENT_LIST_DIR}/cmake" L0) - message ("Enable multi GPU support: ${MULTI_GPU_SUPPORT}") - message ("COMPUTE_RUNTIME_TARGET_NAME: 
${COMPUTE_RUNTIME_TARGET_NAME}") - endif() -endif(COMPUTE_RUNTIME) - -if (MULTI_GPU_SUPPORT) - option(CCL_GPU_DEVICES_AFFINITY_ENABLE "Enable L0" ON) - if(CCL_GPU_DEVICES_AFFINITY_ENABLE) - set(CCL_GPU_DEVICES_AFFINITY_MASK_SIZE 4) - message ("Set L0 device mask affinity size: ${CCL_GPU_DEVICES_AFFINITY_MASK_SIZE}") - endif() -endif(MULTI_GPU_SUPPORT) - -if (CCL_ENABLE_SYCL) - if (${PARENT_COMPUTE_RUNTIME_TARGET_NAME} STREQUAL "Intel::SYCL") - set (CCL_ENABLE_SYCL_CHECK_CONTRACT "#if defined(__cplusplus)\n#if !defined(__clang__) || __clang_major__ < 9 || !defined(CL_SYCL_LANGUAGE_VERSION)\n#error This version of CCL configured only for oneAPI DPC++ Compiler\n#endif\n#endif") - else(${PARENT_COMPUTE_RUNTIME_TARGET_NAME} STREQUAL "Codeplay::ComputeCpp") - set (CCL_ENABLE_SYCL_CHECK_CONTRACT "#if defined(__cplusplus)\n#if !defined(__clang__) || __clang_major__ < 6\n#error This version of CCL configured only for oneAPI DPC++ Compiler\n#endif\n#endif") - endif() -endif() - #generate & install vars.sh configure_file(cmake/vars.sh.in ${CMAKE_CURRENT_BINARY_DIR}/vars.sh @ONLY) configure_file(cmake/setvars.sh.in ${CMAKE_CURRENT_BINARY_DIR}/setvars.sh @ONLY) @@ -259,15 +259,15 @@ install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/ccl DESTINATION ${CCL_INSTALL_MODUL install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/third-party-programs.txt DESTINATION ${CCL_INSTALL_LICENSE}) install(PROGRAMS ${PROJECT_SOURCE_DIR}/LICENSE DESTINATION ${CCL_INSTALL_LICENSE}) -set(CCL_MAJOR_VERSION "0") -set(CCL_MINOR_VERSION "10") +set(CCL_MAJOR_VERSION "2021") +set(CCL_MINOR_VERSION "1") set(CCL_UPDATE_VERSION "0") -set(CCL_PRODUCT_STATUS "beta") +set(CCL_PRODUCT_STATUS "Gold") string(TIMESTAMP CCL_PRODUCT_BUILD_DATE "%Y-%m-%dT %H:%M:%SZ") get_vcs_properties("git") set(CCL_PRODUCT_FULL "${CCL_PRODUCT_STATUS}-${CCL_MAJOR_VERSION}.${CCL_MINOR_VERSION}.${CCL_UPDATE_VERSION} ${CCL_PRODUCT_BUILD_DATE} ${VCS_INFO}") -configure_file(${PROJECT_SOURCE_DIR}/include/oneapi/ccl/ccl_config.h.in 
"${CMAKE_CURRENT_BINARY_DIR}/include/oneapi/ccl/ccl_config.h") -file(COPY "${CMAKE_CURRENT_BINARY_DIR}/include/oneapi/ccl/ccl_config.h" DESTINATION ${PROJECT_SOURCE_DIR}/include/oneapi/ccl) +configure_file(${PROJECT_SOURCE_DIR}/include/oneapi/ccl/config.h.in "${CMAKE_CURRENT_BINARY_DIR}/include/oneapi/ccl/config.h") +file(COPY "${CMAKE_CURRENT_BINARY_DIR}/include/oneapi/ccl/config.h" DESTINATION ${PROJECT_SOURCE_DIR}/include/oneapi/ccl) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) @@ -282,4 +282,7 @@ if (CCL_ENABLE_SYCL) endif() add_subdirectory(tests/functional) -#add_subdirectory(tests/unit) + +if (BUILD_UT) + #add_subdirectory(tests/unit) +endif() diff --git a/cmake/FindIntelSYCL.cmake b/cmake/FindIntelSYCL.cmake index 1a773f823..ed7775dd0 100644 --- a/cmake/FindIntelSYCL.cmake +++ b/cmake/FindIntelSYCL.cmake @@ -26,14 +26,13 @@ endif() set(OPENCLROOT "${dpcpp_root_hints}/include/sycl/CL/") -if(MULTI_GPU_SUPPORT) - find_package(L0 REQUIRED) +if(TRY_ENABLE_SYCL_L0) + find_package(L0) if(LevelZero_FOUND) set(COMPUTE_RUNTIME_NAME ze_loader) endif() endif() - if (NOT COMPUTE_RUNTIME_NAME) message("Not OpenCL or L0") endif() diff --git a/cmake/FindL0.cmake b/cmake/FindL0.cmake index da69bf6ef..607b61de8 100644 --- a/cmake/FindL0.cmake +++ b/cmake/FindL0.cmake @@ -16,6 +16,10 @@ endif() list(INSERT CMAKE_PREFIX_PATH 0 ${l0_root_hints}) +if (TARGET ze_loader) + set(LevelZero_FOUND ON) +endif() + if(NOT TARGET ze_loader) find_path(LevelZero_INCLUDE_DIR NAMES ze_api.h @@ -23,8 +27,10 @@ if(NOT TARGET ze_loader) ENV ZE_ROOT ${l0_root_hints} PATH_SUFFIXES + include + include/level_zero local/include - local/include/level_zero/ + local/include/level_zero NO_DEFAULT_PATH ) @@ -35,8 +41,9 @@ if(NOT TARGET ze_loader) ${l0_root_hints} PATH_SUFFIXES lib + lib/x86_64-linux-gnu + lib/level_zero local/lib - lib/level_zero/ local/lib/level_zero NO_DEFAULT_PATH ) @@ -58,15 +65,14 @@ if(NOT TARGET ze_loader) message("L0 is using OpenCL interoperability") list(APPEND 
LevelZero_INCLUDE_DIRS ${OpenCL_INCLUDE_DIRS}) endif() + add_library(ze_loader INTERFACE IMPORTED) + set_target_properties(ze_loader + PROPERTIES INTERFACE_LINK_LIBRARIES "${LevelZero_LIBRARIES}" + ) + set_target_properties(ze_loader + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${LevelZero_INCLUDE_DIRS}" + ) endif() - - add_library(ze_loader INTERFACE IMPORTED) - set_target_properties(ze_loader - PROPERTIES INTERFACE_LINK_LIBRARIES "${LevelZero_LIBRARIES}" - ) - set_target_properties(ze_loader - PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${LevelZero_INCLUDE_DIRS}" - ) endif() # Reverting the CMAKE_PREFIX_PATH to its original state diff --git a/cmake/helpers.cmake b/cmake/helpers.cmake index 4002320ea..5d6cf2063 100644 --- a/cmake/helpers.cmake +++ b/cmake/helpers.cmake @@ -45,7 +45,10 @@ endfunction(get_vcs_properties) function(activate_compute_runtime MODULES_PATH COMPUTE_RUNTIME) string( TOLOWER "${COMPUTE_RUNTIME}" COMPUTE_RUNTIME) + set(CCL_ENABLE_SYCL_V 0 PARENT_SCOPE) + set(CCL_ENABLE_SYCL_L0 0 PARENT_SCOPE) + message("Search Compute Runtime by MODULES_PATH: ${MODULES_PATH}") list(APPEND CMAKE_MODULE_PATH "${MODULES_PATH}") @@ -54,9 +57,12 @@ function(activate_compute_runtime MODULES_PATH COMPUTE_RUNTIME) SET (COMPUTE_RUNTIME_LOAD_MODULE "ComputeCpp" CACHE STRING "COMPUTE_RUNTIME=${COMPUTE_RUNTIME} requested. Using ComputeCpp provider") + find_package(${COMPUTE_RUNTIME_LOAD_MODULE} REQUIRED) - set (CCL_ENABLE_SYCL_V 1 PARENT_SCOPE) + if(NOT ComputeCpp_FOUND) + message(FATAL_ERROR "Failed to find ComputeCpp") + endif() # remember compilation flags, because flag required for OBJECTS target # but if we use `target_link_libraries`, then these flags applied to all compiler options @@ -74,8 +80,17 @@ function(activate_compute_runtime MODULES_PATH COMPUTE_RUNTIME) SET (COMPUTE_RUNTIME_LOAD_MODULE "IntelSYCL" CACHE STRING "COMPUTE_RUNTIME=${COMPUTE_RUNTIME} requested. 
Using DPC++ provider") + find_package(${COMPUTE_RUNTIME_LOAD_MODULE} REQUIRED) + if(NOT IntelSYCL_FOUND) + message(FATAL_ERROR "Failed to find IntelSYCL") + endif() + + if(LevelZero_FOUND) + set(CCL_ENABLE_SYCL_L0 1 PARENT_SCOPE) + endif() + set(CCL_ENABLE_SYCL_V 1 PARENT_SCOPE) # remember compilation flags, because flag required for OBJECTS target @@ -93,8 +108,14 @@ function(activate_compute_runtime MODULES_PATH COMPUTE_RUNTIME) SET (COMPUTE_RUNTIME_LOAD_MODULE "L0" CACHE STRING "COMPUTE_RUNTIME=${COMPUTE_RUNTIME} requested") + find_package(${COMPUTE_RUNTIME_LOAD_MODULE} REQUIRED) + if(NOT LevelZero_FOUND) + message(STATUS "Can not find level-zero") + return() + endif() + # No compiler flags set (COMPUTE_RUNTIME_CXXFLAGS_LOCAL "") @@ -103,6 +124,10 @@ function(activate_compute_runtime MODULES_PATH COMPUTE_RUNTIME) set (COMPUTE_RUNTIME_TARGET_NAME ze_loader PARENT_SCOPE) endif() + if (NOT COMPUTE_RUNTIME_TARGET_NAME) + message(FATAL_ERROR "Failed to find requested compute runtime: ${COMPUTE_RUNTIME}") + endif() + # extract target properties get_target_property(COMPUTE_RUNTIME_INCLUDE_DIRS_LOCAL ${COMPUTE_RUNTIME_TARGET_NAME} INTERFACE_INCLUDE_DIRECTORIES) diff --git a/cmake/vars.sh.in b/cmake/vars.sh.in index c77f71eab..ba9d5bc54 100644 --- a/cmake/vars.sh.in +++ b/cmake/vars.sh.in @@ -1,3 +1,4 @@ +#!/bin/bash # # Copyright 2016-2020 Intel Corporation # @@ -13,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -#!/bin/bash get_script_path() ( script="$1" diff --git a/doc/rst/source/device_communication.rst b/doc/rst/source/device_communication.rst index 842bfa502..4e8ad3c0d 100644 --- a/doc/rst/source/device_communication.rst +++ b/doc/rst/source/device_communication.rst @@ -17,7 +17,7 @@ Consider a simple oneCCL ``allreduce`` example for GPU: .. 
code:: cpp - auto ccl_device_context = ccl::create_context(sycl_context); + auto ccl_context = ccl::create_context(sycl_context); auto ccl_device = ccl::create_device(sycl_device); auto comms = ccl::create_communicators( @@ -56,9 +56,9 @@ Consider a simple oneCCL ``allreduce`` example for GPU: .. code:: cpp /* using SYCL buffer and accessor */ - sycl_queue.submit([&](cl::sycl::handler& cgh) { - auto send_buf_dev_acc = send_buf.get_access(cgh); - cgh.parallel_for(range<1>{elem_count}, [=](item<1> idx) { + sycl_queue.submit([&](cl::sycl::handler& h) { + auto send_buf_dev_acc = send_buf.get_access(h); + h.parallel_for(range<1>{elem_count}, [=](item<1> idx) { send_buf_dev_acc[idx] += 1; }); }); @@ -100,9 +100,9 @@ Consider a simple oneCCL ``allreduce`` example for GPU: auto comm_size = comm.size(); auto expected = comm_size * (comm_size + 1) / 2; - sycl_queue.submit([&](handler& cgh) { - auto recv_buf_dev_acc = recv_buf.get_access(cgh); - cgh.parallel_for(range<1>{elem_count}, [=](item<1> idx) { + sycl_queue.submit([&](handler& h) { + auto recv_buf_dev_acc = recv_buf.get_access(h); + h.parallel_for(range<1>{elem_count}, [=](item<1> idx) { if (recv_buf_dev_acc[idx] != expected) { recv_buf_dev_acc[idx] = -1; } diff --git a/doc/rst/source/env_variables.rst b/doc/rst/source/env_variables.rst index ec6374891..e5cf0f7d4 100755 --- a/doc/rst/source/env_variables.rst +++ b/doc/rst/source/env_variables.rst @@ -534,3 +534,28 @@ CCL_LOG_LEVEL **Description** Set this environment variable to control logging level. + + +CCL_MAX_SHORT_SIZE +################## +**Syntax** + +:: + + CCL_MAX_SHORT_SIZE= + +**Arguments** + +.. list-table:: + :widths: 25 50 + :header-rows: 1 + :align: left + + * - + - Description + * - ``SIZE`` + - Bytes threshold for a collective operation (``0`` if not specified). If the size of a communication buffer in bytes is less than or equal to ``SIZE``, then |product_short| does not split operation between workers. 
Applicable for ``allreduce``, ``reduce`` and ``broadcast``. + +**Description** + +Set this environment variable to specify the threshold of the number of bytes for a collective operation to be split. diff --git a/doc/rst/source/index.rst b/doc/rst/source/index.rst index 63e525be9..5c3430aa0 100755 --- a/doc/rst/source/index.rst +++ b/doc/rst/source/index.rst @@ -12,7 +12,7 @@ - Optimized to drive scalability of communication patterns by allowing to easily trade-off compute for communication performance. - Enables a set of DL-specific optimizations, such as prioritization, persistent operations, or out-of-order execution. - Works across various interconnects: Intel(R) Omni-Path Architecture, InfiniBand*, and Ethernet. -- Provides common API sufficient to support communication workflows within Deep Learning frameworks (such as PyTorch*, Horovod*). +- Provides common API sufficient to support communication workflows within Deep Learning / distributed frameworks (such as PyTorch*, Horovod*). |product_short| package comprises the |product_short| Software Development Kit (SDK) and the Intel(R) MPI Library Runtime components. @@ -34,6 +34,7 @@ Contents: specification.rst host_communication.rst device_communication.rst + limitations.rst .. toctree:: :maxdepth: 1 diff --git a/doc/rst/source/installation.rst b/doc/rst/source/installation.rst index 987a820e8..03faf19f0 100755 --- a/doc/rst/source/installation.rst +++ b/doc/rst/source/installation.rst @@ -10,11 +10,7 @@ Installation This page explains how to install and configure the |product_full| (|product_short|). -|product_short| supports different installation scenarios: - -* `Installation using command line interface`_ -* `Installation using tar.gz`_ -* `Installation using RPM`_ +|product_short| supports different installation scenarios using command line interface. .. note:: Visit |sys_req|_ to learn about hardware and software requirements for |product_short|. 
@@ -64,93 +60,30 @@ You can customize CLI-based installation (for example, specify directory, compil :: - cmake .. -DCMAKE_INSTALL_PREFIX=/path/to/installation/directory + cmake .. -DCMAKE_INSTALL_PREFIX= - If no ``-DCMAKE_INSTALL_PREFIX`` is specified, |product_short| is installed into the ``_install`` subdirectory of the current build directory. - For example, ``ccl/build/_install``. + If no ``-DCMAKE_INSTALL_PREFIX`` is specified, |product_short| is installed into the ``_install`` subdirectory of the current build directory. For example, ``ccl/build/_install``. * To specify **compiler**, modify the ``cmake`` command: :: - cmake .. -DCMAKE_C_COMPILER=your_c_compiler -DCMAKE_CXX_COMPILER=your_cxx_compiler - - If ``CMAKE_CXX_COMPILER`` requires ``SYCL`` cross-platform abstraction level it should be specified in ``-DCOMPUTE_RUNTIME`` ( ``compute++`` and ``dpcpp`` supported only): - - :: - - cmake .. -DCMAKE_C_COMPILER=your_c_compiler -DCMAKE_CXX_COMPILER=compute++ -DCOMPUTE_RUNTIME=computecpp - cmake .. -DCMAKE_C_COMPILER=your_c_compiler -DCMAKE_CXX_COMPILER=dpcpp -DCOMPUTE_RUNTIME=dpcpp + cmake .. -DCMAKE_C_COMPILER= -DCMAKE_CXX_COMPILER= - OpenCL search location path hint can be specified by using standart environment ``OPENCLROOT`` additionally: + To enable ``SYCL`` devices communication support specify ``SYCL`` compiler and set ``-DCOMPUTE_RUNTIME`` (DPC++ supported only): :: - OPENCLROOT=your_opencl_location cmake .. -DCMAKE_C_COMPILER=your_c_compiler -DCMAKE_CXX_COMPILER=compute++ -DCOMPUTE_RUNTIME=computecpp - + cmake .. -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=dpcpp -DCOMPUTE_RUNTIME=dpcpp * To specify the **build type**, modify the ``cmake`` command: :: - cmake .. -DCMAKE_BUILD_TYPE=[Debug|Release|RelWithDebInfo|MinSizeRel] + cmake .. 
-DCMAKE_BUILD_TYPE=[Debug|Release] * To enable ``make`` verbose output to see all parameters used by ``make`` during compilation and linkage, modify the ``make`` command as follows: :: - make -j VERBOSE=1 - -* To archive installed files: - - :: - - make -j install - -* To build with Address Sanitizer, modify the ``cmake`` command as follow: - - :: - - cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_ASAN=true - - Make sure that ``libasan.so`` exists. - - .. note:: - - Address sanitizer only works in the debug build. - -Binary releases are available on our release page. - -Installation using tar.gz -************************* - -To install |product_short| using the |tgz_file|_ in a user mode, execute the following commands: - -.. prompt:: bash - - tar zxf l_ccl-devel-64-...tgz - cd l_ccl_.. - ./install.sh - -There is no uninstall script. To uninstall |product_short|, delete the whole installation directory. - -Installation using RPM -********************** - -You can get |product_short| through the RPM Package Manager. To install the library in a root mode using RPM, follow these steps: - -#. Log in as root. - -#. Install the following package: - - .. prompt:: bash - - rpm -i intel-ccl-devel-64-.-.x86_64.rpm - - where ``.-`` is a string. For example, ``2017.0-009``. - -To uninstall |product_short| using the RPM Package Manager, execute this command: - - .. 
prompt:: bash - - rpm -e intel-ccl-devel-64-.-.x86_64 + make -j VERBOSE=1 install diff --git a/doc/rst/source/limitations.rst b/doc/rst/source/limitations.rst new file mode 100644 index 000000000..004e9e53f --- /dev/null +++ b/doc/rst/source/limitations.rst @@ -0,0 +1,9 @@ +=========== +Limitations +=========== + +The list of scenarios not yet supported by oneCCL: + +- Creation of multiple ranks within a single process +- Handling of dependencies as operation parameter (for example, ``deps`` vector in ``ccl::allreduce(..., deps)``) +- Float16 datatype support diff --git a/doc/rst/source/sample.rst b/doc/rst/source/sample.rst index e69e04958..f1db10d70 100755 --- a/doc/rst/source/sample.rst +++ b/doc/rst/source/sample.rst @@ -25,15 +25,23 @@ The sample code below shows how to use |product_short| API to perform allreduce MPI_Comm_size(MPI_COMM_WORLD, &size); MPI_Comm_rank(MPI_COMM_WORLD, &rank); + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { - MPI_Finalize(); + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } - /* allocate USM buffers */ - auto send_buf = aligned_alloc_shared(64, count, q); - auto recv_buf = aligned_alloc_shared(64, count, q); + buf_allocator allocator(q); + + auto usm_alloc_type = usm::alloc::shared; + if (argc > 2) { + usm_alloc_type = usm_alloc_type_from_string(argv[2]); + } + + if (!check_sycl_usm(q, usm_alloc_type)) { + return -1; + } /* create kvs */ ccl::shared_ptr_class kvs; @@ -56,18 +64,15 @@ The sample code below shows how to use |product_short| API to perform allreduce /* create stream */ auto stream = ccl::create_stream(q); - { - /* open buffers and initialize them on the host side */ - for (i = 0; i < count; i++) { - send_buf[i] = rank; - recv_buf[i] = -1; - } - } + /* create buffers */ + auto send_buf = allocator.allocate(count, usm_alloc_type); + auto recv_buf = allocator.allocate(count, usm_alloc_type); - /* open send_buf and modify it on the device side */ - q.submit([&](auto &h) { + /* open 
buffers and modify them on the device side */ + auto e = q.submit([&](auto &h) { h.parallel_for(count, [=](auto id) { - send_buf[id] += 1; + send_buf[id] = rank + 1; + recv_buf[id] = -1; }); }); @@ -105,21 +110,14 @@ The sample code below shows how to use |product_short| API to perform allreduce } } - free(send_buf, q); - free(recv_buf, q); - - MPI_Finalize(); - return 0; } - - Build details ************* -#. |product_short| should be built with SYCL* support. +#. |product_short| should be built with ``SYCL`` support (DPC++ supported only). #. Set up the library environment (see :doc:`prerequisites`). diff --git a/doc/rst/source/sparse_collectives.rst b/doc/rst/source/sparse_collectives.rst index 4e4440bcf..5c30baf9e 100755 --- a/doc/rst/source/sparse_collectives.rst +++ b/doc/rst/source/sparse_collectives.rst @@ -84,13 +84,13 @@ Completion callback should follow the signature: typedef void (*completion_fn) ( - const void*, // idx_buf - size_t, // idx_count - ccl::datatype, // idx_dtype - const void*, // val_buf - size_t, // val_count - ccl::datatype, // val_dtype - const void* // user_context + const void*, /* idx_buf */ + size_t, /* idx_count */ + ccl::datatype, /* idx_dtype */ + const void*, /* val_buf */ + size_t, /* val_count */ + ccl::datatype, /* val_dtype */ + const void* /* user_context */ ); Note that ``idx_buf`` and ``val_buf`` are temporary buffers. @@ -103,18 +103,14 @@ Allocation callback should follow the signature: typedef void (*alloc_fn) ( - size_t, // idx_count - ccl::datatype, // idx_dtype - size_t, // val_count - ccl::datatype, // val_dtype - const void*, // user_context - void**, // out_idx_buf - void** // out_val_buf + size_t, /* idx_count */ + ccl::datatype, /* idx_dtype */ + size_t, /* val_count */ + ccl::datatype, /* val_dtype */ + const void*, /* user_context */ + void**, /* out_idx_buf */ + void** /* out_val_buf */ ); - -For more details, refer to `this example `_. - - .. 
note:: WARNING: ``ccl::sparse_allreduce`` is experimental and subject to change. diff --git a/examples/benchmark/CMakeLists.txt b/examples/benchmark/CMakeLists.txt index 71f32028d..24ce78b4d 100644 --- a/examples/benchmark/CMakeLists.txt +++ b/examples/benchmark/CMakeLists.txt @@ -15,6 +15,11 @@ # file(GLOB sources "./src/*.c" "./src/*.cpp") +if (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") + set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD_REQUIRED ON) +endif() + include_directories(include) include_directories(src) @@ -29,5 +34,5 @@ foreach(src ${sources}) target_link_libraries(${executable} PUBLIC dl) target_link_libraries(${executable} PRIVATE m) target_link_libraries(${executable} PUBLIC mpi) - install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/benchmark) + install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/benchmark OPTIONAL) endforeach() diff --git a/examples/benchmark/include/benchmark.hpp b/examples/benchmark/include/benchmark.hpp index 5a5fc9a09..8525b498c 100644 --- a/examples/benchmark/include/benchmark.hpp +++ b/examples/benchmark/include/benchmark.hpp @@ -36,134 +36,52 @@ using namespace cl::sycl; using namespace cl::sycl::access; #endif /* CCL_ENABLE_SYCL */ +#include "base.hpp" #include "base_utils.hpp" #include "bf16.hpp" #include "coll.hpp" #include "sparse_allreduce/sparse_detail.hpp" -/* specific benchmark variables */ -// TODO: add ccl::bf16 -constexpr std::initializer_list all_dtypes = { - ccl::datatype::int8, ccl::datatype::int32, ccl::datatype::float32, - ccl::datatype::float64, ccl::datatype::int64, ccl::datatype::uint64 -}; - -/* specific benchmark defines */ - -#define PRINT(fmt, ...) printf(fmt "\n", ##__VA_ARGS__); - -#ifndef PRINT_BY_ROOT -#define PRINT_BY_ROOT(comm, fmt, ...) \ - if (comm.rank() == 0) { \ - printf(fmt "\n", ##__VA_ARGS__); \ - } -#endif //PRINT_BY_ROOT - -#define ASSERT(cond, fmt, ...) 
\ - do { \ - if (!(cond)) { \ - printf("FAILED\n"); \ - fprintf(stderr, "ASSERT '%s' FAILED " fmt "\n", #cond, ##__VA_ARGS__); \ - throw std::runtime_error("ASSERT FAILED"); \ - } \ - } while (0) - -typedef enum { BACKEND_HOST, BACKEND_SYCL } backend_type_t; -typedef enum { LOOP_REGULAR, LOOP_UNORDERED } loop_type_t; -typedef enum { BUF_SINGLE, BUF_MULTI } buf_type_t; - -#define DEFAULT_BACKEND BACKEND_HOST -#define DEFAULT_LOOP LOOP_REGULAR -#define DEFAULT_BUF BUF_SINGLE - -std::map backend_names = { - std::make_pair(BACKEND_HOST, "host"), - std::make_pair(BACKEND_SYCL, "sycl") -}; - -std::map loop_names = { std::make_pair(LOOP_REGULAR, "regular"), - std::make_pair(LOOP_UNORDERED, "unordered") }; - -std::map buf_names = { std::make_pair(BUF_MULTI, "multi"), - std::make_pair(BUF_SINGLE, "single") }; - -// TODO: add ccl::bf16 -std::map dtype_names = { - std::make_pair(ccl::datatype::int8, "char"), - std::make_pair(ccl::datatype::int32, "int"), - std::make_pair(ccl::datatype::float32, "float"), - std::make_pair(ccl::datatype::float64, "double"), - std::make_pair(ccl::datatype::int64, "int64"), - std::make_pair(ccl::datatype::uint64, "uint64"), -}; - -std::map reduction_names = { - std::make_pair(ccl::reduction::sum, "sum"), - std::make_pair(ccl::reduction::prod, "prod"), - std::make_pair(ccl::reduction::min, "min"), - std::make_pair(ccl::reduction::max, "max"), -}; - -// variables for setting dtypes to launch benchmark -// TODO: add ccl::bf16 -template -using checked_dtype_t = std::pair; -using supported_dtypes_t = std::tuple, - checked_dtype_t, - checked_dtype_t, - checked_dtype_t, - checked_dtype_t, - checked_dtype_t>; -supported_dtypes_t launch_dtypes; - -/* specific benchmark functions */ void print_help_usage(const char* app) { - PRINT( - "\nUSAGE:\n" - "\t%s [OPTIONS]\n\n" - "OPTIONS:\n" - "\t[-b,--backend ]: %s\n" - "\t[-e,--loop ]: %s\n" - "\t[-l,--coll ]: %s\n" - "\t[-i,--iters ]: %d\n" - "\t[-w,--warmup_iters ]: %d\n" - "\t[-p,--buf_count ]: %d\n" - 
"\t[-f,--min_elem_count ]: %d\n" - "\t[-t,--max_elem_count ]: %d\n" - "\t[-c,--check ]: %d\n" - "\t[-v,--v2i_ratio ]: %d\n" - "\t[-d,--dtype ]: %s\n" - "\t[-r,--reduction ]: %s\n" - "\t[-n,--buf_type ]: %s\n" - "\t[-o,--csv_filepath ]: %s\n" - "\t[-h,--help]\n\n" - "example:\n\t--coll allgatherv,allreduce,sparse_allreduce,sparse_allreduce_bf16 --backend host --loop regular\n" - "example:\n\t--coll bcast,reduce --backend sycl --loop unordered \n", - app, - backend_names[DEFAULT_BACKEND].c_str(), - loop_names[DEFAULT_LOOP].c_str(), - DEFAULT_COLL_LIST, - DEFAULT_ITERS, - DEFAULT_WARMUP_ITERS, - DEFAULT_BUF_COUNT, - DEFAULT_MIN_ELEM_COUNT, - DEFAULT_MAX_ELEM_COUNT, - DEFAULT_CHECK_VALUES, - DEFAULT_V2I_RATIO, - DEFAULT_DTYPES_LIST, - DEFAULT_REDUCTIONS_LIST, - buf_names[DEFAULT_BUF_TYPE].c_str(), - DEFAULT_CSV_FILEPATH); -} - -std::list tokenize(const std::string& input, char delimeter) { - std::stringstream ss(input); - std::list ret; - std::string value; - while (std::getline(ss, value, delimeter)) { - ret.push_back(value); - } - return ret; + PRINT("\nUSAGE:\n" + "\t%s [OPTIONS]\n\n" + "OPTIONS:\n" + "\t[-b,--backend ]: %s\n" + "\t[-e,--loop ]: %s\n" + "\t[-i,--iters ]: %d\n" + "\t[-w,--warmup_iters ]: %d\n" + "\t[-p,--buf_count ]: %d\n" + "\t[-f,--min_elem_count ]: %d\n" + "\t[-t,--max_elem_count ]: %d\n" + "\t[-c,--check ]: %d\n" + "\t[-a,--sycl_dev_type ]: %s\n" + "\t[-m,--sycl_mem_type ]: %s\n" + "\t[-u,--sycl_usm_type ]: %s\n" + "\t[-k,--ranks_per_proc ]: %d\n" + "\t[-l,--coll ]: %s\n" + "\t[-d,--dtype ]: %s\n" + "\t[-r,--reduction ]: %s\n" + "\t[-o,--csv_filepath ]: %s\n" + "\t[-h,--help]\n\n" + "example:\n\t--coll allgatherv,allreduce --backend host --loop regular\n" + "example:\n\t--coll bcast,reduce --backend sycl --loop unordered \n", + app, + backend_names[DEFAULT_BACKEND].c_str(), + loop_names[DEFAULT_LOOP].c_str(), + DEFAULT_ITERS, + DEFAULT_WARMUP_ITERS, + DEFAULT_BUF_COUNT, + DEFAULT_MIN_ELEM_COUNT, + DEFAULT_MAX_ELEM_COUNT, + DEFAULT_CHECK_VALUES, + 
sycl_dev_names[DEFAULT_SYCL_DEV_TYPE].c_str(), + sycl_mem_names[DEFAULT_SYCL_MEM_TYPE].c_str(), + sycl_usm_names[DEFAULT_SYCL_USM_TYPE].c_str(), + DEFAULT_RANKS_PER_PROC, + DEFAULT_COLL_LIST, + DEFAULT_DTYPES_LIST, + DEFAULT_REDUCTIONS_LIST, + DEFAULT_CSV_FILEPATH); } template @@ -237,95 +155,109 @@ int set_loop(const std::string& option_value, loop_type_t& loop) { return 0; } -int set_buf_type(const std::string& option_value, buf_type_t& buf) { - std::string option_name = "buf_type"; - std::set supported_option_values{ buf_names[BUF_SINGLE], buf_names[BUF_MULTI] }; +int set_sycl_dev_type(const std::string& option_value, sycl_dev_type_t& dev) { + std::string option_name = "sycl_dev_type"; + std::set supported_option_values{ sycl_dev_names[SYCL_DEV_HOST], + sycl_dev_names[SYCL_DEV_CPU], + sycl_dev_names[SYCL_DEV_GPU] }; + + if (check_supported_options(option_name, option_value, supported_option_values)) + return -1; + + if (option_value == sycl_dev_names[SYCL_DEV_HOST]) + dev = SYCL_DEV_HOST; + else if (option_value == sycl_dev_names[SYCL_DEV_CPU]) + dev = SYCL_DEV_CPU; + else if (option_value == sycl_dev_names[SYCL_DEV_GPU]) + dev = SYCL_DEV_GPU; + + return 0; +} + +int set_sycl_mem_type(const std::string& option_value, sycl_mem_type_t& mem) { + std::string option_name = "sycl_mem_type"; + std::set supported_option_values{ sycl_mem_names[SYCL_MEM_USM], + sycl_mem_names[SYCL_MEM_BUF] }; if (check_supported_options(option_name, option_value, supported_option_values)) return -1; - buf = (option_value == buf_names[BUF_SINGLE]) ? BUF_SINGLE : BUF_MULTI; + mem = (option_value == sycl_mem_names[SYCL_MEM_USM]) ? 
SYCL_MEM_USM : SYCL_MEM_BUF; return 0; } -// leave this dtype here because of tokenize() call -typedef struct user_options_t { - backend_type_t backend; - loop_type_t loop; - size_t iters; - size_t warmup_iters; - size_t buf_count; - size_t min_elem_count; - size_t max_elem_count; - int check_values; - buf_type_t buf_type; - size_t v2i_ratio; - std::list coll_names; - std::list dtypes; - std::list reductions; - std::string csv_filepath; - - user_options_t() { - backend = DEFAULT_BACKEND; - loop = DEFAULT_LOOP; - coll_names = tokenize(DEFAULT_COLL_LIST, ','); - iters = DEFAULT_ITERS; - warmup_iters = DEFAULT_WARMUP_ITERS; - buf_count = DEFAULT_BUF_COUNT; - min_elem_count = DEFAULT_MIN_ELEM_COUNT; - max_elem_count = DEFAULT_MAX_ELEM_COUNT; - check_values = DEFAULT_CHECK_VALUES; - buf_type = DEFAULT_BUF_TYPE; - v2i_ratio = DEFAULT_V2I_RATIO; - dtypes = tokenize(DEFAULT_DTYPES_LIST, ','); - reductions = tokenize(DEFAULT_REDUCTIONS_LIST, ','); - csv_filepath = std::string(DEFAULT_CSV_FILEPATH); +int set_sycl_usm_type(const std::string& option_value, sycl_usm_type_t& usm) { + std::string option_name = "sycl_usm_type"; + std::set supported_option_values{ sycl_usm_names[SYCL_USM_SHARED], + sycl_usm_names[SYCL_USM_DEVICE] }; + + if (check_supported_options(option_name, option_value, supported_option_values)) + return -1; + + usm = (option_value == sycl_usm_names[SYCL_USM_SHARED]) ? SYCL_USM_SHARED : SYCL_USM_DEVICE; + + return 0; +} + +size_t get_iter_count(size_t bytes, size_t max_iter_count) { + size_t n, res = max_iter_count; + n = bytes >> 18; + while (n) { + res >>= 1; + n >>= 1; } -} user_options_t; -/* placing print_timings() here is because of declaration of user_options_t */ -// FIXME FS: what? 
-void print_timings(const ccl::communicator& comm, - const std::vector& timer, + if (!res && max_iter_count) + res = 1; + + return res; +} + +/* timer array contains one number per collective, one collective corresponds to rank_per_proc ranks */ +void print_timings(ccl::communicator& comm, + const std::vector& local_timers, const user_options_t& options, const size_t elem_count, + const size_t iter_count, ccl::datatype dtype, ccl::reduction op) { - const size_t buf_count = options.buf_type == BUF_SINGLE ? 1 : options.buf_count; + const size_t buf_count = options.buf_count; const size_t ncolls = options.coll_names.size(); std::vector all_timers(ncolls * comm.size()); std::vector recv_counts(comm.size()); - size_t idx; + int idx; for (idx = 0; idx < comm.size(); idx++) recv_counts[idx] = ncolls; - ccl::allgatherv(timer.data(), ncolls, all_timers.data(), recv_counts, comm).wait(); + ccl::allgatherv(local_timers.data(), ncolls, all_timers.data(), recv_counts, comm).wait(); if (comm.rank() == 0) { std::vector timers(comm.size(), 0); - for (size_t r = 0; r < comm.size(); ++r) { + for (int r = 0; r < comm.size(); ++r) { for (size_t c = 0; c < ncolls; ++c) { timers[r] += all_timers[r * ncolls + c]; } } + double avg_timer(0); double avg_timer_per_buf(0); for (idx = 0; idx < comm.size(); idx++) { avg_timer += timers[idx]; } - avg_timer /= (options.iters * comm.size()); + avg_timer /= (iter_count * comm.size()); avg_timer_per_buf = avg_timer / buf_count; double stddev_timer = 0; double sum = 0; for (idx = 0; idx < comm.size(); idx++) { - double val = timers[idx] / options.iters; + double val = timers[idx] / iter_count; sum += (val - avg_timer) * (val - avg_timer); } + stddev_timer = sqrt(sum / comm.size()) / avg_timer * 100; - if (options.buf_type == BUF_SINGLE) { + if (buf_count == 1) { printf("%10zu %12.2lf %11.1lf\n", elem_count * ccl::get_datatype_size(dtype) * buf_count, avg_timer, @@ -345,57 +277,41 @@ void print_timings(const ccl::communicator& comm, if 
(!options.csv_filepath.empty()) { std::ofstream csvf; csvf.open(options.csv_filepath, std::ios::app); + if (csvf.is_open()) { std::vector avg_timer(ncolls, 0); - for (size_t r = 0; r < comm.size(); ++r) { + + for (int r = 0; r < comm.size(); ++r) { for (size_t c = 0; c < ncolls; ++c) { avg_timer[c] += all_timers[r * ncolls + c]; } } + for (size_t c = 0; c < ncolls; ++c) { - avg_timer[c] /= (options.iters * comm.size()); + avg_timer[c] /= (iter_count * comm.size()); } - int idx = 0; + int i = 0; for (auto cop = options.coll_names.begin(); cop != options.coll_names.end(); - ++cop, ++idx) { + ++cop, ++i) { csvf << comm.size() << "," << (*cop) << "," << reduction_names[op] << "," - << dtype_names[dtype] << "," - << ccl::get_datatype_size(dtype) << "," - << elem_count << "," << buf_count << "," << avg_timer[idx] << std::endl; + << dtype_names[dtype] << "," << ccl::get_datatype_size(dtype) << "," + << elem_count << "," << buf_count << "," << avg_timer[i] << std::endl; } csvf.close(); } } } + ccl::barrier(comm); } -/* specific benchmark functors */ -class set_dtypes_func { -private: - const std::list& dtypes; - -public: - set_dtypes_func(const std::list& dtypes) : dtypes(dtypes) {} - - template - void operator()(checked_dtype_t& val) { - auto it = std::find(dtypes.begin(), dtypes.end(), ccl::native_type_info::name()); - if (it != std::end(dtypes)) { - val.first = true; - } - } -}; - -int parse_user_options(int& argc, - char**(&argv), - user_options_t& options) { +int parse_user_options(int& argc, char**(&argv), user_options_t& options) { int ch; int errors = 0; - // values needed by getopt - const char* const short_options = "b:e:i:w:p:f:t:c:v:l:d:r:n:o:h:"; + const char* const short_options = "b:e:i:w:p:f:t:c:v:o:a:m:u:k:l:d:r:h"; + struct option getopt_options[] = { { "backend", required_argument, 0, 'b' }, { "loop", required_argument, 0, 'e' }, @@ -406,10 +322,13 @@ int parse_user_options(int& argc, { "max_elem_count", required_argument, 0, 't' }, { "check", 
required_argument, 0, 'c' }, { "v2i_ratio", required_argument, 0, 'v' }, + { "sycl_dev_type", required_argument, 0, 'a' }, + { "sycl_mem_type", required_argument, 0, 'm' }, + { "sycl_usm_type", required_argument, 0, 'u' }, + { "ranks", required_argument, 0, 'k' }, { "coll", required_argument, 0, 'l' }, { "dtype", required_argument, 0, 'd' }, { "reduction", required_argument, 0, 'r' }, - { "buf_type", required_argument, 0, 'n' }, { "csv_filepath", required_argument, 0, 'o' }, { "help", no_argument, 0, 'h' }, { 0, 0, 0, 0 } // required at end of array. @@ -418,12 +337,16 @@ int parse_user_options(int& argc, while ((ch = getopt_long(argc, argv, short_options, getopt_options, NULL)) != -1) { switch (ch) { case 'b': - if (set_backend(optarg, options.backend)) + if (set_backend(optarg, options.backend)) { + PRINT("failed to parse 'backend' option"); errors++; + } break; case 'e': - if (set_loop(optarg, options.loop)) + if (set_loop(optarg, options.loop)) { + PRINT("failed to parse 'loop' option"); errors++; + } break; case 'i': options.iters = atoll(optarg); break; case 'w': options.warmup_iters = atoll(optarg); break; @@ -432,7 +355,32 @@ int parse_user_options(int& argc, case 't': options.max_elem_count = atoll(optarg); break; case 'c': options.check_values = atoi(optarg); break; case 'v': options.v2i_ratio = atoll(optarg); break; - case 'l': options.coll_names = tokenize(optarg, ','); break; + case 'a': + if (set_sycl_dev_type(optarg, options.sycl_dev_type)) { + PRINT("failed to parse 'sycl_dev_type' option"); + errors++; + } + break; + case 'm': + if (set_sycl_mem_type(optarg, options.sycl_mem_type)) { + PRINT("failed to parse 'sycl_mem_type' option"); + errors++; + } + break; + case 'u': + if (set_sycl_usm_type(optarg, options.sycl_usm_type)) { + PRINT("failed to parse 'sycl_usm_type' option"); + errors++; + } + break; + case 'k': options.ranks_per_proc = atoll(optarg); break; + case 'l': + if (strcmp("all", optarg) == 0) { + options.coll_names = 
tokenize(ALL_COLLS_LIST, ','); + } + else + options.coll_names = tokenize(optarg, ','); + break; case 'd': if (strcmp("all", optarg) == 0) { options.dtypes = tokenize(ALL_DTYPES_LIST, ','); @@ -447,13 +395,12 @@ int parse_user_options(int& argc, else options.reductions = tokenize(optarg, ','); break; - case 'n': - if (set_buf_type(optarg, options.buf_type)) - errors++; - break; case 'o': options.csv_filepath = std::string(optarg); break; - case 'h': print_help_usage(argv[0]); return -1; - default: errors++; break; + case 'h': return -1; + default: + PRINT("failed to parse unknown option"); + errors++; + break; } } @@ -464,14 +411,23 @@ int parse_user_options(int& argc, if (errors > 0) { PRINT("found %d errors while parsing user options", errors); + for (int idx = 0; idx < argc; idx++) { + PRINT("arg %d: %s", idx, argv[idx]); + } return -1; } + /* adjust user options */ + if (!options.min_elem_count) + options.min_elem_count = 1; + + if (options.max_elem_count < options.min_elem_count) + options.max_elem_count = options.min_elem_count; + return 0; } -void print_user_options(const user_options_t& options, - const ccl::communicator& comm) { +void print_user_options(const user_options_t& options, const ccl::communicator& comm) { std::stringstream ss; ss << "colls: "; std::copy(options.coll_names.begin(), @@ -487,11 +443,14 @@ void print_user_options(const user_options_t& options, std::string backend_str = find_str_val(backend_names, options.backend); std::string loop_str = find_str_val(loop_names, options.loop); - std::string buf_type_str = find_str_val(buf_names, options.buf_type); + + std::string sycl_dev_type_str = find_str_val(sycl_dev_names, options.sycl_dev_type); + std::string sycl_mem_type_str = find_str_val(sycl_mem_names, options.sycl_mem_type); + std::string sycl_usm_type_str = find_str_val(sycl_usm_names, options.sycl_usm_type); PRINT_BY_ROOT(comm, "options:" - "\n ranks: %zu" + "\n processes: %d" "\n backend: %s" "\n loop: %s" "\n iters: %zu" @@ -500,8 
+459,11 @@ void print_user_options(const user_options_t& options, "\n min_elem_count: %zu" "\n max_elem_count: %zu" "\n check: %d" - "\n buf_type: %s" "\n v2i_ratio: %zu" + "\n sycl_dev_type: %s" + "\n sycl_mem_type: %s" + "\n sycl_usm_type: %s" + "\n ranks_per_proc: %zu" "\n %s" "\n csv_filepath: %s", comm.size(), @@ -513,8 +475,11 @@ void print_user_options(const user_options_t& options, options.min_elem_count, options.max_elem_count, options.check_values, - buf_type_str.c_str(), options.v2i_ratio, + sycl_dev_type_str.c_str(), + sycl_mem_type_str.c_str(), + sycl_usm_type_str.c_str(), + options.ranks_per_proc, ss.str().c_str(), options.csv_filepath.c_str()); } diff --git a/examples/benchmark/include/coll.hpp b/examples/benchmark/include/coll.hpp index 8959c141e..975c6923b 100644 --- a/examples/benchmark/include/coll.hpp +++ b/examples/benchmark/include/coll.hpp @@ -15,22 +15,23 @@ */ #pragma once -#include "base.hpp" #include "config.hpp" +#include "transport.hpp" +#include "types.hpp" #ifdef CCL_ENABLE_SYCL -#include "sycl_base.hpp" template using sycl_buffer_t = cl::sycl::buffer; #endif +#define COLL_ROOT (0) + struct base_coll; using coll_list_t = std::vector>; using req_list_t = std::vector; typedef struct bench_exec_attr { - bench_exec_attr() = default; template struct setter { @@ -44,8 +45,7 @@ typedef struct bench_exec_attr { struct factory { template void operator()(ccl::shared_ptr_class& attr) { - attr = std::make_shared( - ccl::create_operation_attr()); + attr = std::make_shared(ccl::create_operation_attr()); } }; @@ -55,8 +55,8 @@ typedef struct bench_exec_attr { ccl::shared_ptr_class, ccl::shared_ptr_class, ccl::shared_ptr_class, - ccl::shared_ptr_class, - ccl::shared_ptr_class>; + ccl::shared_ptr_class/*, + ccl::shared_ptr_class*/>; template attr_t& get_attr() { @@ -69,8 +69,8 @@ typedef struct bench_exec_attr { } template - typename ccl::details::ccl_api_type_attr_traits::return_type - set(const Value& v) { + typename 
ccl::detail::ccl_api_type_attr_traits::return_type set( + const Value& v) { ccl_tuple_for_each(coll_attrs, setter(v)); return v; } @@ -88,6 +88,9 @@ typedef struct bench_exec_attr { typedef struct bench_init_attr { size_t buf_count; size_t max_elem_count; + size_t ranks_per_proc; + sycl_mem_type_t sycl_mem_type; + sycl_usm_type_t sycl_usm_type; size_t v2i_ratio; } bench_init_attr; @@ -96,6 +99,11 @@ struct base_coll { base_coll(bench_init_attr init_attr) : init_attr(init_attr) { send_bufs.resize(init_attr.buf_count); recv_bufs.resize(init_attr.buf_count); + + for (size_t idx = 0; idx < init_attr.buf_count; idx++) { + send_bufs[idx].resize(init_attr.ranks_per_proc); + recv_bufs[idx].resize(init_attr.ranks_per_proc); + } } base_coll() = delete; @@ -105,99 +113,74 @@ struct base_coll { return nullptr; }; - virtual void prepare(size_t elem_count){}; - virtual void finalize(size_t elem_count){}; + virtual void prepare(size_t elem_count) { + auto& transport = transport_data::instance(); + auto& comms = transport.get_comms(); + auto streams = transport.get_bench_streams(); + size_t ranks_per_proc = base_coll::get_ranks_per_proc(); + + for (size_t rank_idx = 0; rank_idx < ranks_per_proc; rank_idx++) { + prepare_internal(elem_count, comms[rank_idx], streams[rank_idx], rank_idx); + } + } + + virtual void finalize(size_t elem_count) { + auto& transport = transport_data::instance(); + auto& comms = transport.get_comms(); + auto streams = transport.get_bench_streams(); + size_t ranks_per_proc = base_coll::get_ranks_per_proc(); + + for (size_t rank_idx = 0; rank_idx < ranks_per_proc; rank_idx++) { + finalize_internal(elem_count, comms[rank_idx], streams[rank_idx], rank_idx); + } + } + + virtual void prepare_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) = 0; + + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) = 0; virtual ccl::datatype get_dtype() const = 0; + 
size_t get_dtype_size() const { + return ccl::get_datatype_size(get_dtype()); + } + virtual void start(size_t count, size_t buf_idx, const bench_exec_attr& attr, req_list_t& reqs) = 0; - virtual void start_single(size_t count, const bench_exec_attr& attr, req_list_t& reqs) = 0; - /* to get buf_count from initialized private member */ size_t get_buf_count() const noexcept { return init_attr.buf_count; } + size_t get_max_elem_count() const noexcept { return init_attr.max_elem_count; } - size_t get_single_buf_max_elem_count() const noexcept { - return init_attr.buf_count * init_attr.max_elem_count; - } - - std::vector send_bufs; - std::vector recv_bufs; - - void* single_send_buf = nullptr; - void* single_recv_buf = nullptr; - -private: - bench_init_attr init_attr; -}; - -struct host_data { - static ccl::shared_ptr_class comm_ptr; - static void init(size_t size, size_t rank, ccl::shared_ptr_class kvs) { - - if (comm_ptr) { - throw ccl::exception(std::string(__FUNCTION__) + " - reinit is not allowed"); - } - - comm_ptr = std::make_shared( - ccl::create_communicator(size, rank, kvs)); - } - static void deinit() { - comm_ptr.reset(); + sycl_mem_type_t get_sycl_mem_type() const noexcept { + return init_attr.sycl_mem_type; } -}; - -ccl::shared_ptr_class host_data::comm_ptr{}; - -#ifdef CCL_ENABLE_SYCL -struct device_data { - static ccl::shared_ptr_class comm_ptr; - static ccl::shared_ptr_class stream_ptr; - static cl::sycl::queue sycl_queue; - - static void init(size_t size, - size_t rank, - cl::sycl::device& device, - cl::sycl::context& ctx, - ccl::shared_ptr_class kvs) { - - if (stream_ptr or comm_ptr) { - throw ccl::exception(std::string(__FUNCTION__) + " - reinit is not allowed"); - } - - auto ccl_dev = ccl::create_device(device); - auto ccl_ctx = ccl::create_context(ctx); - - comm_ptr = std::make_shared( - ccl::create_communicator( - size, rank, - ccl_dev, - ccl_ctx, - kvs)); - - sycl_queue = cl::sycl::queue(device); - - stream_ptr = - 
std::make_shared(ccl::create_stream(sycl_queue)); + sycl_usm_type_t get_sycl_usm_type() const noexcept { + return init_attr.sycl_usm_type; } - static void deinit() { - comm_ptr.reset(); - stream_ptr.reset(); + size_t get_ranks_per_proc() const noexcept { + return init_attr.ranks_per_proc; } -}; -ccl::shared_ptr_class device_data::comm_ptr{}; -ccl::shared_ptr_class device_data::stream_ptr{}; -cl::sycl::queue device_data::sycl_queue{}; + // first dim - per buf_count, second dim - per local rank + std::vector> send_bufs; + std::vector> recv_bufs; -#endif /* CCL_ENABLE_SYCL */ +private: + bench_init_attr init_attr; +}; diff --git a/examples/benchmark/include/config.hpp b/examples/benchmark/include/config.hpp index 3b801a140..9ccad6598 100644 --- a/examples/benchmark/include/config.hpp +++ b/examples/benchmark/include/config.hpp @@ -15,27 +15,28 @@ */ #pragma once -#define ALIGNMENT (4096) -#define DTYPE float +#define ALIGNMENT (4096) +#define DTYPE float -#define ALL_DTYPES_LIST "char,int,float,double,int64_t,uint64_t" +#define ALL_COLLS_LIST "allgatherv,allreduce,alltoall,alltoallv,bcast,reduce,reduce_scatter" +#define ALL_DTYPES_LIST "int8,int32,int64,uint64,float32,float64" #define ALL_REDUCTIONS_LIST "sum,prod,min,max" -#define DEFAULT_BACKEND BACKEND_HOST -#define DEFAULT_LOOP LOOP_REGULAR -#define DEFAULT_COLL_LIST \ - "allgatherv,allreduce,alltoall,alltoallv,bcast,reduce," \ - "reduce_scatter,sparse_allreduce,sparse_allreduce_bf16," \ - "allgatherv,allreduce,alltoall,alltoallv,bcast,reduce," \ - "reduce_scatter,sparse_allreduce,sparse_allreduce_bf16" -#define DEFAULT_ITERS (16) -#define DEFAULT_WARMUP_ITERS (16) -#define DEFAULT_BUF_COUNT (16) -#define DEFAULT_MIN_ELEM_COUNT (1) -#define DEFAULT_MAX_ELEM_COUNT (128) -#define DEFAULT_CHECK_VALUES (1) -#define DEFAULT_BUF_TYPE BUF_MULTI -#define DEFAULT_V2I_RATIO (128) -#define DEFAULT_DTYPES_LIST "float" +#define DEFAULT_BACKEND BACKEND_HOST +#define DEFAULT_LOOP LOOP_REGULAR +#define DEFAULT_ITERS (16) +#define 
DEFAULT_WARMUP_ITERS (16) +#define DEFAULT_BUF_COUNT (16) +#define DEFAULT_MIN_ELEM_COUNT (1) +#define DEFAULT_MAX_ELEM_COUNT (128) +#define DEFAULT_CHECK_VALUES (1) +#define DEFAULT_V2I_RATIO (128) +#define DEFAULT_SYCL_DEV_TYPE SYCL_DEV_GPU +#define DEFAULT_SYCL_MEM_TYPE SYCL_MEM_USM +#define DEFAULT_SYCL_USM_TYPE SYCL_USM_DEVICE +#define DEFAULT_RANKS_PER_PROC (1) + +#define DEFAULT_COLL_LIST "allreduce" +#define DEFAULT_DTYPES_LIST "float32" #define DEFAULT_REDUCTIONS_LIST "sum" #define DEFAULT_CSV_FILEPATH "" diff --git a/examples/benchmark/include/cpu_coll.hpp b/examples/benchmark/include/cpu_coll.hpp index c064e3a53..dd33cfac4 100644 --- a/examples/benchmark/include/cpu_coll.hpp +++ b/examples/benchmark/include/cpu_coll.hpp @@ -19,48 +19,43 @@ /* cpu-specific base implementation */ template -struct cpu_base_coll : base_coll, protected strategy, host_data { +struct cpu_base_coll : base_coll, protected strategy { using coll_strategy = strategy; template - cpu_base_coll(bench_init_attr init_attr, - size_t sbuf_multiplier, - size_t rbuf_multiplier, - Args&&... args) + cpu_base_coll(bench_init_attr init_attr, Args&&... args) : base_coll(init_attr), - coll_strategy(std::forward(args)...) 
{ + coll_strategy() { int result = 0; - for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { - result = - posix_memalign((void**)&send_bufs[idx], - ALIGNMENT, - base_coll::get_max_elem_count() * sizeof(Dtype) * sbuf_multiplier); - result = - posix_memalign((void**)&recv_bufs[idx], - ALIGNMENT, - base_coll::get_max_elem_count() * sizeof(Dtype) * rbuf_multiplier); + size_t send_multiplier = coll_strategy::get_send_multiplier(); + size_t recv_multiplier = coll_strategy::get_recv_multiplier(); + + for (size_t rank_idx = 0; rank_idx < base_coll::get_ranks_per_proc(); rank_idx++) { + for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { + result = posix_memalign( + (void**)&(send_bufs[idx][rank_idx]), + ALIGNMENT, + base_coll::get_max_elem_count() * sizeof(Dtype) * send_multiplier); + result = posix_memalign( + (void**)&(recv_bufs[idx][rank_idx]), + ALIGNMENT, + base_coll::get_max_elem_count() * sizeof(Dtype) * recv_multiplier); + } } - result = posix_memalign( - (void**)&single_send_buf, - ALIGNMENT, - base_coll::get_single_buf_max_elem_count() * sizeof(Dtype) * sbuf_multiplier); - result = posix_memalign( - (void**)&single_recv_buf, - ALIGNMENT, - base_coll::get_single_buf_max_elem_count() * sizeof(Dtype) * rbuf_multiplier); - (void)result; + + ASSERT(result == 0, "failed to allocate buffers"); } cpu_base_coll(bench_init_attr init_attr) : cpu_base_coll(init_attr, 1, 1) {} virtual ~cpu_base_coll() { - for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { - free(send_bufs[idx]); - free(recv_bufs[idx]); + for (size_t rank_idx = 0; rank_idx < base_coll::get_ranks_per_proc(); rank_idx++) { + for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { + free(send_bufs[idx][rank_idx]); + free(recv_bufs[idx][rank_idx]); + } } - free(single_send_buf); - free(single_recv_buf); } const char* name() const noexcept override { @@ -71,35 +66,44 @@ struct cpu_base_coll : base_coll, protected strategy, host_data { size_t buf_idx, const 
bench_exec_attr& attr, req_list_t& reqs) override { - coll_strategy::start_internal(comm(), - count, - static_cast(send_bufs[buf_idx]), - static_cast(recv_bufs[buf_idx]), - attr, - reqs, - coll_strategy::get_op_attr(attr)); - } + auto& transport = transport_data::instance(); + auto& comms = transport.get_comms(); + size_t ranks_per_proc = base_coll::get_ranks_per_proc(); - virtual void start_single(size_t count, - const bench_exec_attr& attr, - req_list_t& reqs) override { - coll_strategy::start_internal(comm(), - count, - static_cast(single_send_buf), - static_cast(single_recv_buf), - attr, - reqs, - coll_strategy::get_op_attr(attr)); + for (size_t rank_idx = 0; rank_idx < ranks_per_proc; rank_idx++) { + coll_strategy::start_internal(comms[rank_idx], + count, + static_cast(send_bufs[buf_idx][rank_idx]), + static_cast(recv_bufs[buf_idx][rank_idx]), + attr, + reqs, + coll_strategy::get_op_attr(attr)); + } } - ccl::datatype get_dtype() const override final { - return ccl::native_type_info::type>::ccl_datatype_value; - } + virtual void prepare_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + int local_rank = comm.rank(); + + size_t send_count = coll_strategy::get_send_multiplier() * elem_count; + size_t recv_count = coll_strategy::get_recv_multiplier() * elem_count; - /* global communicator for all cpu collectives */ - static ccl::communicator& comm() { - if (!host_data::comm_ptr) { + size_t send_bytes = send_count * base_coll::get_dtype_size(); + size_t recv_bytes = recv_count * base_coll::get_dtype_size(); + + std::vector fill_vector(send_count); + std::fill(fill_vector.begin(), fill_vector.end(), local_rank); + + for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { + memcpy(send_bufs[b_idx][rank_idx], fill_vector.data(), send_bytes); + + memset(recv_bufs[b_idx][rank_idx], 0, recv_bytes); } - return *host_data::comm_ptr; + } + + ccl::datatype get_dtype() const override final { + return 
ccl::native_type_info::type>::dtype; } }; diff --git a/examples/benchmark/include/sycl_coll.hpp b/examples/benchmark/include/sycl_coll.hpp index abb7a9a34..748bbcc3d 100644 --- a/examples/benchmark/include/sycl_coll.hpp +++ b/examples/benchmark/include/sycl_coll.hpp @@ -15,88 +15,83 @@ */ #pragma once +#include +#include +#include +#include + #include "coll.hpp" +#include "sycl_base.hpp" /* from examples/include */ #ifdef CCL_ENABLE_SYCL -cl::sycl::device get_device(const ccl::communicator& comm) { - - // select requested platform by SYCL_BE: L0 or OpenCL - std::vector all_devices = - cl::sycl::device::get_devices(info::device_type::gpu); - std::vector selected_devices; - std::string backend; +#include - if (getenv("SYCL_BE") == nullptr) { - backend = "Level-Zero"; - } - else if (getenv("SYCL_BE") != nullptr) { - if (std::strcmp(getenv("SYCL_BE"), "PI_LEVEL_ZERO") == 0) { - backend = "Level-Zero"; - } - else if (std::strcmp(getenv("SYCL_BE"), "PI_OPENCL") == 0) { - backend = "OpenCL"; - } - else { - throw std::runtime_error("invalid backend: " + std::string(getenv("SYCL_BE"))); - } - } - - for (const auto& device : all_devices) { - auto platform = device.get_platform(); - auto platform_name = platform.get_info(); - std::size_t found = platform_name.find(backend); - if (found != std::string::npos) - selected_devices.push_back(device); - } - - if (selected_devices.size() <= 0) { - throw ccl::exception("no selected device found for SYCL backend: " + backend); - } - - size_t idx = comm.rank() % selected_devices.size(); - auto selected_device = selected_devices[idx]; - std::cout << "\nrunning on: " << selected_device.get_info() - << ", device index: " << idx << "\n"; - - return selected_device; -} +using namespace sycl; +using namespace sycl::access; /* sycl-specific base implementation */ template -struct sycl_base_coll : base_coll, private strategy, device_data { +struct sycl_base_coll : base_coll, private strategy { using coll_strategy = strategy; template - 
sycl_base_coll(bench_init_attr init_attr, - size_t sbuf_multiplier, - size_t rbuf_multiplier, - Args&&... args) + sycl_base_coll(bench_init_attr init_attr, Args&&... args) : base_coll(init_attr), - coll_strategy(std::forward(args)...) { - for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { - send_bufs[idx] = - new cl::sycl::buffer(base_coll::get_max_elem_count() * sbuf_multiplier); - recv_bufs[idx] = - new cl::sycl::buffer(base_coll::get_max_elem_count() * rbuf_multiplier); + coll_strategy() { + auto& transport = transport_data::instance(); + auto streams = transport.get_bench_streams(); + + size_t send_multiplier = coll_strategy::get_send_multiplier(); + size_t recv_multiplier = coll_strategy::get_recv_multiplier(); + + for (size_t rank_idx = 0; rank_idx < base_coll::get_ranks_per_proc(); rank_idx++) { + if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { + allocators.push_back(buf_allocator(streams[rank_idx].get_native())); + + auto& allocator = allocators[rank_idx]; + + sycl::usm::alloc usm_alloc_type; + auto bench_alloc_type = base_coll::get_sycl_usm_type(); + if (bench_alloc_type == SYCL_USM_SHARED) + usm_alloc_type = usm::alloc::shared; + else if (bench_alloc_type == SYCL_USM_DEVICE) + usm_alloc_type = usm::alloc::device; + else + ASSERT(0, "unexpected bench_alloc_type %d", bench_alloc_type); + + for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { + send_bufs[idx][rank_idx] = allocator.allocate( + base_coll::get_max_elem_count() * send_multiplier, usm_alloc_type); + recv_bufs[idx][rank_idx] = allocator.allocate( + base_coll::get_max_elem_count() * recv_multiplier, usm_alloc_type); + } + } + else { + for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { + send_bufs[idx][rank_idx] = new cl::sycl::buffer( + base_coll::get_max_elem_count() * send_multiplier); + recv_bufs[idx][rank_idx] = new cl::sycl::buffer( + base_coll::get_max_elem_count() * recv_multiplier); + } + } } - single_send_buf = new cl::sycl::buffer( - 
base_coll::get_single_buf_max_elem_count() * sbuf_multiplier); - - single_recv_buf = new cl::sycl::buffer( - base_coll::get_single_buf_max_elem_count() * rbuf_multiplier); + host_send_buf.resize(base_coll::get_max_elem_count() * send_multiplier); + host_recv_buf.resize(base_coll::get_max_elem_count() * recv_multiplier); } sycl_base_coll(bench_init_attr init_attr) : sycl_base_coll(init_attr, 1, 1) {} virtual ~sycl_base_coll() { - for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { - delete static_cast*>(send_bufs[idx]); - delete static_cast*>(recv_bufs[idx]); + for (size_t rank_idx = 0; rank_idx < base_coll::get_ranks_per_proc(); rank_idx++) { + if (base_coll::get_sycl_mem_type() == SYCL_MEM_BUF) { + for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { + delete static_cast*>(send_bufs[idx][rank_idx]); + delete static_cast*>(recv_bufs[idx][rank_idx]); + } + } } - delete static_cast*>(single_send_buf); - delete static_cast*>(single_recv_buf); } const char* name() const noexcept override { @@ -107,50 +102,96 @@ struct sycl_base_coll : base_coll, private strategy, device_data { size_t buf_idx, const bench_exec_attr& attr, req_list_t& reqs) override { - sycl_buffer_t& send_buf = *(static_cast*>(send_bufs[buf_idx])); - sycl_buffer_t& recv_buf = *(static_cast*>(recv_bufs[buf_idx])); - coll_strategy::template start_internal&>( - comm(), - count, - send_buf, - recv_buf, - attr, - reqs, - stream(), - coll_strategy::get_op_attr(attr)); + auto& transport = transport_data::instance(); + auto& comms = transport.get_comms(); + auto streams = transport.get_streams(); + size_t ranks_per_proc = base_coll::get_ranks_per_proc(); + + for (size_t rank_idx = 0; rank_idx < ranks_per_proc; rank_idx++) { + if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { + coll_strategy::start_internal(comms[rank_idx], + count, + static_cast(send_bufs[buf_idx][rank_idx]), + static_cast(recv_bufs[buf_idx][rank_idx]), + attr, + reqs, + streams[rank_idx], + 
coll_strategy::get_op_attr(attr)); + } + else { + sycl_buffer_t& send_buf = + *(static_cast*>(send_bufs[buf_idx][rank_idx])); + sycl_buffer_t& recv_buf = + *(static_cast*>(recv_bufs[buf_idx][rank_idx])); + coll_strategy::template start_internal&>( + comms[rank_idx], + count, + send_buf, + recv_buf, + attr, + reqs, + streams[rank_idx], + coll_strategy::get_op_attr(attr)); + } + } } - virtual void start_single(size_t count, - const bench_exec_attr& attr, - req_list_t& reqs) override { - sycl_buffer_t& send_buf = *(static_cast*>(single_send_buf)); - sycl_buffer_t& recv_buf = *(static_cast*>(single_recv_buf)); - coll_strategy::template start_internal&>( - comm(), - count, - send_buf, - recv_buf, - attr, - reqs, - stream(), - coll_strategy::get_op_attr(attr)); + virtual void prepare_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + int comm_rank = comm.rank(); + + size_t send_count = coll_strategy::get_send_multiplier() * elem_count; + size_t recv_count = coll_strategy::get_recv_multiplier() * elem_count; + + size_t send_bytes = send_count * base_coll::get_dtype_size(); + size_t recv_bytes = recv_count * base_coll::get_dtype_size(); + + std::fill(host_send_buf.begin(), host_send_buf.end(), comm_rank); + + for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { + if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { + stream.get_native() + .memcpy(send_bufs[b_idx][rank_idx], host_send_buf.data(), send_bytes) + .wait(); + + stream.get_native().memset(recv_bufs[b_idx][rank_idx], 0, recv_bytes).wait(); + } + else { + stream.get_native() + .submit([&](handler& h) { + auto send_buf = + (static_cast*>(send_bufs[b_idx][rank_idx])); + auto send_buf_acc = + send_buf->template get_access(h, send_count); + h.fill(send_buf_acc, static_cast(comm_rank)); + }) + .wait(); + + stream.get_native() + .submit([&](handler& h) { + auto recv_buf = + (static_cast*>(recv_bufs[b_idx][rank_idx])); + auto recv_buf_acc = + 
recv_buf->template get_access(h, recv_count); + h.fill(recv_buf_acc, static_cast(0)); + }) + .wait(); + } + } } ccl::datatype get_dtype() const override final { - return ccl::native_type_info::type>::ccl_datatype_value; + return ccl::native_type_info::type>::dtype; } - /* global communicator & stream for all cpu collectives */ - static ccl::communicator& comm() { - if (!device_data::comm_ptr) { - } - return *device_data::comm_ptr; - } + /* used on fill/check phases */ + std::vector host_send_buf; + std::vector host_recv_buf; - static ccl::stream& stream() { - if (!device_data::stream_ptr) { - } - return *device_data::stream_ptr; - } +private: + std::vector> allocators; }; + #endif /* CCL_ENABLE_SYCL */ diff --git a/examples/benchmark/include/transport.hpp b/examples/benchmark/include/transport.hpp index b1f2b27ed..b0202e1e8 100644 --- a/examples/benchmark/include/transport.hpp +++ b/examples/benchmark/include/transport.hpp @@ -15,23 +15,48 @@ */ #pragma once -#include "base_utils.hpp" +#include +#include -class transport_settings { +#include "oneapi/ccl.hpp" +#include "types.hpp" + +class transport_data { public: - static transport_settings& instance(); + static transport_data& instance(); + static size_t get_comm_size(); + int get_rank() const noexcept; int get_size() const noexcept; ccl::shared_ptr_class get_kvs(); + ccl::communicator& get_service_comm(); + void init_comms(user_options_t& options); + std::vector& get_comms(); + + std::vector& get_streams(); + std::vector& get_bench_streams(); private: - transport_settings(); - ~transport_settings(); + transport_data(); + ~transport_data(); int rank; int size; + + std::vector local_ranks; + ccl::shared_ptr_class kvs; + std::vector service_comms; + std::vector comms; + + /* + FIXME: explicitly separate CCL and bench streams + while runtime doesn't provide MT on the same queue + */ + std::vector streams; + std::vector bench_streams; + void init_by_mpi(); void deinit_by_mpi(); }; diff --git 
a/examples/benchmark/include/types.hpp b/examples/benchmark/include/types.hpp new file mode 100644 index 000000000..7facfb8d8 --- /dev/null +++ b/examples/benchmark/include/types.hpp @@ -0,0 +1,133 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "oneapi/ccl.hpp" + +#define PRINT(fmt, ...) printf(fmt "\n", ##__VA_ARGS__); + +#ifndef PRINT_BY_ROOT +#define PRINT_BY_ROOT(comm, fmt, ...) \ + if (comm.rank() == 0) { \ + printf(fmt "\n", ##__VA_ARGS__); \ + } +#endif //PRINT_BY_ROOT + +#define ASSERT(cond, fmt, ...) 
\ + do { \ + if (!(cond)) { \ + printf("FAILED\n"); \ + fprintf(stderr, "ASSERT '%s' FAILED " fmt "\n", #cond, ##__VA_ARGS__); \ + throw std::runtime_error("ASSERT FAILED"); \ + } \ + } while (0) + +// TODO: add ccl::bfloat16 +constexpr std::initializer_list all_dtypes = { + ccl::datatype::int8, ccl::datatype::int32, ccl::datatype::float32, + ccl::datatype::float64, ccl::datatype::int64, ccl::datatype::uint64 +}; + +typedef enum { BACKEND_HOST, BACKEND_SYCL } backend_type_t; +typedef enum { LOOP_REGULAR, LOOP_UNORDERED } loop_type_t; + +typedef enum { SYCL_DEV_HOST, SYCL_DEV_CPU, SYCL_DEV_GPU } sycl_dev_type_t; +typedef enum { SYCL_MEM_USM, SYCL_MEM_BUF } sycl_mem_type_t; +typedef enum { SYCL_USM_SHARED, SYCL_USM_DEVICE } sycl_usm_type_t; + +std::map backend_names = { std::make_pair(BACKEND_HOST, "host"), + std::make_pair(BACKEND_SYCL, "sycl") }; + +std::map loop_names = { std::make_pair(LOOP_REGULAR, "regular"), + std::make_pair(LOOP_UNORDERED, "unordered") }; + +std::map sycl_dev_names = { std::make_pair(SYCL_DEV_HOST, "host"), + std::make_pair(SYCL_DEV_CPU, "cpu"), + std::make_pair(SYCL_DEV_GPU, "gpu") }; + +std::map sycl_mem_names = { std::make_pair(SYCL_MEM_USM, "usm"), + std::make_pair(SYCL_MEM_BUF, "buf") }; + +std::map sycl_usm_names = { std::make_pair(SYCL_USM_SHARED, "shared"), + std::make_pair(SYCL_USM_DEVICE, + "device") }; + +// TODO: add ccl::bfloat16 +std::map dtype_names = { + std::make_pair(ccl::datatype::int8, "int8"), + std::make_pair(ccl::datatype::int32, "int32"), + std::make_pair(ccl::datatype::int64, "int64"), + std::make_pair(ccl::datatype::uint64, "uint64"), + std::make_pair(ccl::datatype::float32, "float32"), + std::make_pair(ccl::datatype::float64, "float64") +}; + +std::map reduction_names = { + std::make_pair(ccl::reduction::sum, "sum"), + std::make_pair(ccl::reduction::prod, "prod"), + std::make_pair(ccl::reduction::min, "min"), + std::make_pair(ccl::reduction::max, "max"), +}; + +std::list tokenize(const std::string& input, char 
delimeter) { + std::stringstream ss(input); + std::list ret; + std::string value; + while (std::getline(ss, value, delimeter)) { + ret.push_back(value); + } + return ret; +} + +typedef struct user_options_t { + backend_type_t backend; + loop_type_t loop; + size_t iters; + size_t warmup_iters; + size_t buf_count; + size_t min_elem_count; + size_t max_elem_count; + int check_values; + size_t v2i_ratio; + sycl_dev_type_t sycl_dev_type; + sycl_mem_type_t sycl_mem_type; + sycl_usm_type_t sycl_usm_type; + size_t ranks_per_proc; + std::list coll_names; + std::list dtypes; + std::list reductions; + std::string csv_filepath; + + user_options_t() { + backend = DEFAULT_BACKEND; + loop = DEFAULT_LOOP; + iters = DEFAULT_ITERS; + warmup_iters = DEFAULT_WARMUP_ITERS; + buf_count = DEFAULT_BUF_COUNT; + min_elem_count = DEFAULT_MIN_ELEM_COUNT; + max_elem_count = DEFAULT_MAX_ELEM_COUNT; + check_values = DEFAULT_CHECK_VALUES; + v2i_ratio = DEFAULT_V2I_RATIO; + sycl_dev_type = DEFAULT_SYCL_DEV_TYPE; + sycl_mem_type = DEFAULT_SYCL_MEM_TYPE; + sycl_usm_type = DEFAULT_SYCL_USM_TYPE; + ranks_per_proc = DEFAULT_RANKS_PER_PROC; + coll_names = tokenize(DEFAULT_COLL_LIST, ','); + dtypes = tokenize(DEFAULT_DTYPES_LIST, ','); + reductions = tokenize(DEFAULT_REDUCTIONS_LIST, ','); + csv_filepath = std::string(DEFAULT_CSV_FILEPATH); + } +} user_options_t; diff --git a/examples/benchmark/src/allgatherv/allgatherv_strategy.hpp b/examples/benchmark/src/allgatherv/allgatherv_strategy.hpp index b090c0fa4..895b0dfe9 100644 --- a/examples/benchmark/src/allgatherv/allgatherv_strategy.hpp +++ b/examples/benchmark/src/allgatherv/allgatherv_strategy.hpp @@ -16,41 +16,45 @@ #pragma once struct allgatherv_strategy_impl { - size_t comm_size = 0; std::vector recv_counts; - allgatherv_strategy_impl(size_t size) : comm_size(size) { - recv_counts.resize(size); - //int result = posix_memalign((void**)&recv_counts, ALIGNMENT, comm_size * sizeof(size_t)); - //(void)result; + + allgatherv_strategy_impl() { + 
recv_counts.resize(transport_data::get_comm_size()); } allgatherv_strategy_impl(const allgatherv_strategy_impl&) = delete; allgatherv_strategy_impl& operator=(const allgatherv_strategy_impl&) = delete; - ~allgatherv_strategy_impl() { - //free(recv_counts); - } + ~allgatherv_strategy_impl() {} static constexpr const char* class_name() { return "allgatherv"; } + size_t get_send_multiplier() { + return 1; + } + + size_t get_recv_multiplier() { + return transport_data::get_comm_size(); + } + static const ccl::allgatherv_attr& get_op_attr(const bench_exec_attr& bench_attr) { return bench_attr.get_attr(); } - template - void start_internal(comm_t& comm, + template + void start_internal(ccl::communicator& comm, size_t count, const Dtype send_buf, Dtype recv_buf, const bench_exec_attr& bench_attr, req_list_t& reqs, Args&&... args) { - for (size_t idx = 0; idx < comm_size; idx++) { + for (int idx = 0; idx < comm.size(); idx++) { recv_counts[idx] = count; } - reqs.push_back( - ccl::allgatherv(send_buf, count, recv_buf, recv_counts, comm, std::forward(args)...)); + reqs.push_back(ccl::allgatherv( + send_buf, count, recv_buf, recv_counts, comm, std::forward(args)...)); } }; diff --git a/examples/benchmark/src/allgatherv/cpu_allgatherv_coll.hpp b/examples/benchmark/src/allgatherv/cpu_allgatherv_coll.hpp index e5aaebdd0..795e3259a 100644 --- a/examples/benchmark/src/allgatherv/cpu_allgatherv_coll.hpp +++ b/examples/benchmark/src/allgatherv/cpu_allgatherv_coll.hpp @@ -23,48 +23,35 @@ struct cpu_allgatherv_coll : cpu_base_coll { using coll_base = cpu_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - cpu_allgatherv_coll(bench_init_attr init_attr) - : coll_base(init_attr, 1, coll_base::comm().size(), coll_base::comm().size()) {} + cpu_allgatherv_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { - for (size_t b_idx = 0; b_idx < 
base_coll::get_buf_count(); b_idx++) { - for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - ((Dtype*)send_bufs[b_idx])[e_idx] = coll_base::comm().rank(); - } - - for (size_t idx = 0; idx < coll_base::comm().size(); idx++) { - for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - ((Dtype*)recv_bufs[b_idx])[idx * elem_count + e_idx] = 0; - } - } - } - } - - virtual void finalize(size_t elem_count) override { - Dtype sbuf_expected = coll_base::comm().rank(); + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + Dtype sbuf_expected = comm.rank(); Dtype value; for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - value = ((Dtype*)send_bufs[b_idx])[e_idx]; + value = ((Dtype*)send_bufs[b_idx][rank_idx])[e_idx]; if (value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } - for (size_t idx = 0; idx < coll_base::comm().size(); idx++) { + for (int idx = 0; idx < comm.size(); idx++) { Dtype rbuf_expected = idx; for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - value = ((Dtype*)recv_bufs[b_idx])[idx * elem_count + e_idx]; + value = ((Dtype*)recv_bufs[b_idx][rank_idx])[idx * elem_count + e_idx]; if (value != rbuf_expected) { std::cout << this->name() << " recv_bufs: buf_idx " << b_idx - << ", elem_idx " << e_idx << ", expected " << rbuf_expected - << ", got " << value << std::endl; + << ", rank_idx " << rank_idx << ", elem_idx " << e_idx + << ", expected " << rbuf_expected << ", got " << value + << std::endl; ASSERT(0, "unexpected value"); } } diff --git 
a/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp b/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp index b5bff6f62..c8d465bf6 100644 --- a/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp +++ b/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp @@ -20,105 +20,77 @@ #ifdef CCL_ENABLE_SYCL #include "sycl_coll.hpp" -template -class allgatherv_buf_check {}; - -template -class allatherv_buf_fill {}; - template struct sycl_allgatherv_coll : sycl_base_coll { using coll_base = sycl_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - using coll_base::comm; + using coll_base::host_send_buf; + using coll_base::host_recv_buf; - sycl_allgatherv_coll(bench_init_attr init_attr) - : coll_base(init_attr, 1, coll_base::comm().size(), coll_base::comm().size()) {} + sycl_allgatherv_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { - size_t local_rank = coll_base::comm().rank(); - size_t local_size = coll_base::comm().size(); - - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count}, [=](item<1> e_idx) - { - send_buf_acc[e_idx] = local_rank; - for (size_t idx = 0; idx < local_size; idx++) { - recv_buf_acc[idx * elem_count + e_idx.get_id(0)] = 0; - } - }); - }); - } - } + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + int comm_size = comm.size(); + Dtype sbuf_expected = comm.rank(); - virtual void finalize(size_t elem_count) override { - bool unexpected_device_value = false; - size_t 
local_size = coll_base::comm().size(); - Dtype sbuf_expected = coll_base::comm().rank(); + size_t send_bytes = elem_count * base_coll::get_dtype_size(); + size_t recv_bytes = comm_size * elem_count * base_coll::get_dtype_size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count}, [=](item<1> e_idx) mutable - { - Dtype value = send_buf_acc[e_idx]; - if (value != sbuf_expected) - unexpected_device_value = true; - - for (size_t idx = 0; idx < local_size; idx++) { - Dtype rbuf_expected = idx; - value = recv_buf_acc[idx * elem_count + e_idx.get_id(0)]; - if (value != rbuf_expected) - unexpected_device_value = true; - } - }); - }); - } + if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { + stream.get_native() + .memcpy(host_send_buf.data(), send_bufs[b_idx][rank_idx], send_bytes) + .wait(); + + stream.get_native() + .memcpy(host_recv_buf.data(), recv_bufs[b_idx][rank_idx], recv_bytes) + .wait(); + } + else { + auto send_buf = (static_cast*>(send_bufs[b_idx][rank_idx])); + auto recv_buf = (static_cast*>(recv_bufs[b_idx][rank_idx])); + auto send_buf_acc = send_buf->template get_access(); + auto recv_buf_acc = recv_buf->template get_access(); + + stream.get_native() + .memcpy(host_send_buf.data(), send_buf_acc.get_pointer(), send_bytes) + .wait(); + + stream.get_native() + .memcpy(host_recv_buf.data(), recv_buf_acc.get_pointer(), recv_bytes) + .wait(); + } - Dtype value; - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(); - auto recv_buf_acc = recv_buf->template 
get_access(); + Dtype value; for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - value = send_buf_acc[e_idx]; + value = host_send_buf[e_idx]; if (value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } - for (size_t idx = 0; idx < coll_base::comm().size(); idx++) { + for (int idx = 0; idx < comm.size(); idx++) { Dtype rbuf_expected = idx; for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - value = recv_buf_acc[idx * elem_count + e_idx]; + value = host_recv_buf[idx * elem_count + e_idx]; if (value != rbuf_expected) { std::cout << this->name() << " recv_bufs: buf_idx " << b_idx - << ", elem_idx " << e_idx << ", expected " << rbuf_expected - << ", got " << value << std::endl; + << ", rank_idx " << rank_idx << ", elem_idx " << e_idx + << ", expected " << rbuf_expected << ", got " << value + << std::endl; ASSERT(0, "unexpected value"); } } } } - - if (unexpected_device_value) - ASSERT(0, "unexpected value on device"); } }; diff --git a/examples/benchmark/src/allreduce/allreduce_strategy.hpp b/examples/benchmark/src/allreduce/allreduce_strategy.hpp index 61ef34097..b0ed32a0c 100644 --- a/examples/benchmark/src/allreduce/allreduce_strategy.hpp +++ b/examples/benchmark/src/allreduce/allreduce_strategy.hpp @@ -20,12 +20,20 @@ struct allreduce_strategy_impl { return "allreduce"; } + size_t get_send_multiplier() { + return 1; + } + + size_t get_recv_multiplier() { + return 1; + } + static const ccl::allreduce_attr& get_op_attr(const bench_exec_attr& bench_attr) { return bench_attr.get_attr(); } - template - void start_internal(comm_t& comm, + template + void start_internal(ccl::communicator& comm, size_t count, const 
Dtype send_buf, Dtype recv_buf, diff --git a/examples/benchmark/src/allreduce/cpu_allreduce_coll.hpp b/examples/benchmark/src/allreduce/cpu_allreduce_coll.hpp index e0c6ae072..299c80692 100644 --- a/examples/benchmark/src/allreduce/cpu_allreduce_coll.hpp +++ b/examples/benchmark/src/allreduce/cpu_allreduce_coll.hpp @@ -23,41 +23,31 @@ struct cpu_allreduce_coll : cpu_base_coll { using coll_base = cpu_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - cpu_allreduce_coll(bench_init_attr init_attr) - : coll_base(init_attr) {} + cpu_allreduce_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - ((Dtype*)send_bufs[b_idx])[e_idx] = coll_base::comm().rank(); - ((Dtype*)recv_bufs[b_idx])[e_idx] = 0; - } - } - } - - virtual void finalize(size_t elem_count) override { - Dtype sbuf_expected = coll_base::comm().rank(); - Dtype rbuf_expected = - (coll_base::comm().size() - 1) * ((float)coll_base::comm().size() / 2); + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + Dtype sbuf_expected = comm.rank(); + Dtype rbuf_expected = (comm.size() - 1) * ((float)comm.size() / 2); Dtype value; for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - value = ((Dtype*)send_bufs[b_idx])[e_idx]; + value = ((Dtype*)send_bufs[b_idx][rank_idx])[e_idx]; if (value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " 
+ << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } - value = ((Dtype*)recv_bufs[b_idx])[e_idx]; + value = ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx]; if (value != rbuf_expected) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << rbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << rbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } diff --git a/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp b/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp index 2e4dcd63e..52b5aaf98 100644 --- a/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp +++ b/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp @@ -20,93 +20,71 @@ #ifdef CCL_ENABLE_SYCL #include "sycl_coll.hpp" -template -class allreduce_buf_check {}; - -template -class allreduce_buf_fill {}; - template struct sycl_allreduce_coll : sycl_base_coll { using coll_base = sycl_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - using coll_base::comm; + using coll_base::host_send_buf; + using coll_base::host_recv_buf; - sycl_allreduce_coll(bench_init_attr init_attr) - : coll_base(init_attr) {} + sycl_allreduce_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { - size_t local_rank = coll_base::comm().rank(); - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count}, 
[=](item<1> e_idx) - { - send_buf_acc[e_idx] = local_rank; - recv_buf_acc[e_idx] = 0; - }); - }); - } - } + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + Dtype sbuf_expected = comm.rank(); + Dtype rbuf_expected = (comm.size() - 1) * ((float)comm.size() / 2); - virtual void finalize(size_t elem_count) override { - bool unexpected_device_value = false; - Dtype sbuf_expected = coll_base::comm().rank(); - Dtype rbuf_expected = - (coll_base::comm().size() - 1) * ((float)coll_base::comm().size() / 2); + size_t send_bytes = elem_count * base_coll::get_dtype_size(); + size_t recv_bytes = elem_count * base_coll::get_dtype_size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count}, [=](item<1> e_idx) mutable - { - Dtype value = send_buf_acc[e_idx]; - if (value != sbuf_expected) - unexpected_device_value = true; + if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { + stream.get_native() + .memcpy(host_send_buf.data(), send_bufs[b_idx][rank_idx], send_bytes) + .wait(); - value = recv_buf_acc[e_idx]; - if (value != rbuf_expected) - unexpected_device_value = true; - }); - }); - } + stream.get_native() + .memcpy(host_recv_buf.data(), recv_bufs[b_idx][rank_idx], recv_bytes) + .wait(); + } + else { + auto send_buf = (static_cast*>(send_bufs[b_idx][rank_idx])); + auto recv_buf = (static_cast*>(recv_bufs[b_idx][rank_idx])); + auto send_buf_acc = send_buf->template get_access(); + auto recv_buf_acc = recv_buf->template get_access(); - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto 
recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(); - auto recv_buf_acc = recv_buf->template get_access(); + stream.get_native() + .memcpy(host_send_buf.data(), send_buf_acc.get_pointer(), send_bytes) + .wait(); + + stream.get_native() + .memcpy(host_recv_buf.data(), recv_buf_acc.get_pointer(), recv_bytes) + .wait(); + } + + Dtype value; for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - Dtype value = send_buf_acc[e_idx]; + value = host_send_buf[e_idx]; if (value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } - value = recv_buf_acc[e_idx]; + value = host_recv_buf[e_idx]; if (value != rbuf_expected) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << rbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << rbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } } - - if (unexpected_device_value) - ASSERT(0, "unexpected value on device"); } }; #endif /* CCL_ENABLE_SYCL */ diff --git a/examples/benchmark/src/alltoall/alltoall_strategy.hpp b/examples/benchmark/src/alltoall/alltoall_strategy.hpp index 9a83cbc6d..234ffefa3 100644 --- a/examples/benchmark/src/alltoall/alltoall_strategy.hpp +++ b/examples/benchmark/src/alltoall/alltoall_strategy.hpp @@ -20,12 +20,20 @@ struct alltoall_strategy_impl { return "alltoall"; } + size_t get_send_multiplier() { + return transport_data::get_comm_size(); + } + + size_t get_recv_multiplier() { + return 
transport_data::get_comm_size(); + } + static const ccl::alltoall_attr& get_op_attr(const bench_exec_attr& bench_attr) { return bench_attr.get_attr(); } - template - void start_internal(comm_t& comm, + template + void start_internal(ccl::communicator& comm, size_t count, const Dtype send_buf, Dtype recv_buf, diff --git a/examples/benchmark/src/alltoall/cpu_alltoall_coll.hpp b/examples/benchmark/src/alltoall/cpu_alltoall_coll.hpp index 942c48776..6e4458ca2 100644 --- a/examples/benchmark/src/alltoall/cpu_alltoall_coll.hpp +++ b/examples/benchmark/src/alltoall/cpu_alltoall_coll.hpp @@ -23,44 +23,34 @@ struct cpu_alltoall_coll : cpu_base_coll { using coll_base = cpu_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - cpu_alltoall_coll(bench_init_attr init_attr) - : coll_base(init_attr, coll_base::comm().size(), coll_base::comm().size()) {} + cpu_alltoall_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - for (size_t idx = 0; idx < coll_base::comm().size(); idx++) { - for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - ((Dtype*)send_bufs[b_idx])[idx * elem_count + e_idx] = coll_base::comm().rank(); - ((Dtype*)recv_bufs[b_idx])[idx * elem_count + e_idx] = 0; - } - } - } - } - - virtual void finalize(size_t elem_count) override { - Dtype sbuf_expected = coll_base::comm().rank(); + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + Dtype sbuf_expected = comm.rank(); Dtype rbuf_expected; Dtype value; - size_t comm_size = coll_base::comm().size(); + int comm_size = comm.size(); + for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count * comm_size; e_idx++) { - value = ((Dtype*)send_bufs[b_idx])[e_idx]; + value = 
((Dtype*)send_bufs[b_idx][rank_idx])[e_idx]; rbuf_expected = e_idx / elem_count; if (value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } - value = ((Dtype*)recv_bufs[b_idx])[e_idx]; + value = ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx]; if (value != rbuf_expected) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << rbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << rbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } diff --git a/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp b/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp index 7794884c5..9400551f8 100644 --- a/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp +++ b/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp @@ -20,94 +20,72 @@ #ifdef CCL_ENABLE_SYCL #include "sycl_coll.hpp" -template -class alltoall_buf_check {}; - -template -class alltoall_buf_fill {}; - template struct sycl_alltoall_coll : sycl_base_coll { using coll_base = sycl_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - using coll_base::comm; + using coll_base::host_send_buf; + using coll_base::host_recv_buf; - sycl_alltoall_coll(bench_init_attr init_attr) - : coll_base(init_attr, coll_base::comm().size(), coll_base::comm().size()) {} + sycl_alltoall_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) 
override { - size_t local_rank = coll_base::comm().rank(); - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count*coll_base::comm().size()}, [=](item<1> e_idx) - { - send_buf_acc[e_idx] = local_rank; - recv_buf_acc[e_idx] = 0; - }); - }); - } - } + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + Dtype sbuf_expected = comm.rank(); + int comm_size = comm.size(); - virtual void finalize(size_t elem_count) override { - bool unexpected_device_value = false; - Dtype sbuf_expected = coll_base::comm().rank(); - size_t comm_size = coll_base::comm().size(); + size_t send_bytes = comm_size * elem_count * base_coll::get_dtype_size(); + size_t recv_bytes = comm_size * elem_count * base_coll::get_dtype_size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count * comm_size}, [=](item<1> e_idx) mutable - { - Dtype value = send_buf_acc[e_idx]; - Dtype rbuf_expected = static_cast(e_idx.get_id(0) / elem_count); - if (value != sbuf_expected) - unexpected_device_value = true; + if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { + stream.get_native() + .memcpy(host_send_buf.data(), send_bufs[b_idx][rank_idx], send_bytes) + .wait(); - value = recv_buf_acc[e_idx]; - if (value != rbuf_expected) - unexpected_device_value = true; - }); - }); - } + 
stream.get_native() + .memcpy(host_recv_buf.data(), recv_bufs[b_idx][rank_idx], recv_bytes) + .wait(); + } + else { + auto send_buf = (static_cast*>(send_bufs[b_idx][rank_idx])); + auto recv_buf = (static_cast*>(recv_bufs[b_idx][rank_idx])); + auto send_buf_acc = send_buf->template get_access(); + auto recv_buf_acc = recv_buf->template get_access(); - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(); - auto recv_buf_acc = recv_buf->template get_access(); + stream.get_native() + .memcpy(host_send_buf.data(), send_buf_acc.get_pointer(), send_bytes) + .wait(); + + stream.get_native() + .memcpy(host_recv_buf.data(), recv_buf_acc.get_pointer(), recv_bytes) + .wait(); + } + + Dtype value; for (size_t e_idx = 0; e_idx < elem_count * comm_size; e_idx++) { - Dtype value = send_buf_acc[e_idx]; + value = host_send_buf[e_idx]; Dtype rbuf_expected = e_idx / elem_count; if (value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } - value = recv_buf_acc[e_idx]; + value = host_recv_buf[e_idx]; if (value != rbuf_expected) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << rbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << rbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } } - - if (unexpected_device_value) - ASSERT(0, 
"unexpected value on device"); } }; #endif /* CCL_ENABLE_SYCL */ diff --git a/examples/benchmark/src/alltoallv/alltoallv_strategy.hpp b/examples/benchmark/src/alltoallv/alltoallv_strategy.hpp index 4456c229a..257f77067 100644 --- a/examples/benchmark/src/alltoallv/alltoallv_strategy.hpp +++ b/examples/benchmark/src/alltoallv/alltoallv_strategy.hpp @@ -16,13 +16,12 @@ #pragma once struct alltoallv_strategy_impl { - size_t comm_size = 0; std::vector send_counts; std::vector recv_counts; - alltoallv_strategy_impl(size_t size) : comm_size(size) { - send_counts.resize(comm_size); - recv_counts.resize(comm_size); + alltoallv_strategy_impl() { + send_counts.resize(transport_data::get_comm_size()); + recv_counts.resize(transport_data::get_comm_size()); } alltoallv_strategy_impl(const alltoallv_strategy_impl&) = delete; @@ -34,19 +33,27 @@ struct alltoallv_strategy_impl { return "alltoallv"; } + size_t get_send_multiplier() { + return transport_data::get_comm_size(); + } + + size_t get_recv_multiplier() { + return transport_data::get_comm_size(); + } + static const ccl::alltoallv_attr& get_op_attr(const bench_exec_attr& bench_attr) { return bench_attr.get_attr(); } - template - void start_internal(comm_t& comm, + template + void start_internal(ccl::communicator& comm, size_t count, const Dtype send_buf, Dtype recv_buf, const bench_exec_attr& bench_attr, req_list_t& reqs, Args&&... 
args) { - for (size_t idx = 0; idx < comm_size; idx++) { + for (int idx = 0; idx < comm.size(); idx++) { send_counts[idx] = count; recv_counts[idx] = count; } diff --git a/examples/benchmark/src/alltoallv/cpu_alltoallv_coll.hpp b/examples/benchmark/src/alltoallv/cpu_alltoallv_coll.hpp index 50ac640ae..58eea5922 100644 --- a/examples/benchmark/src/alltoallv/cpu_alltoallv_coll.hpp +++ b/examples/benchmark/src/alltoallv/cpu_alltoallv_coll.hpp @@ -23,48 +23,33 @@ struct cpu_alltoallv_coll : cpu_base_coll { using coll_base = cpu_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - using coll_base::comm; - cpu_alltoallv_coll(bench_init_attr init_attr) - : coll_base(init_attr, - coll_base::comm().size(), - coll_base::comm().size(), - coll_base::comm().size()) {} + cpu_alltoallv_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - for (size_t idx = 0; idx < coll_base::comm().size(); idx++) { - for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - ((Dtype*)send_bufs[b_idx])[idx * elem_count + e_idx] = coll_base::comm().rank(); - ((Dtype*)recv_bufs[b_idx])[idx * elem_count + e_idx] = 0; - } - } - } - } - - virtual void finalize(size_t elem_count) override { - Dtype sbuf_expected = coll_base::comm().rank(); + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + Dtype sbuf_expected = comm.rank(); Dtype rbuf_expected; Dtype value; - size_t comm_size = coll_base::comm().size(); + int comm_size = comm.size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count * comm_size; e_idx++) { - value = ((Dtype*)send_bufs[b_idx])[e_idx]; + value = ((Dtype*)send_bufs[b_idx][rank_idx])[e_idx]; rbuf_expected = e_idx / elem_count; if 
(value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } - value = ((Dtype*)recv_bufs[b_idx])[e_idx]; + value = ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx]; if (value != rbuf_expected) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << rbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << rbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } diff --git a/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp b/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp index 90a49218a..3b018de31 100644 --- a/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp +++ b/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp @@ -20,97 +20,72 @@ #ifdef CCL_ENABLE_SYCL #include "sycl_coll.hpp" -template -class alltoallv_buf_check {}; - -template -class alltoallv_buf_fill {}; - template struct sycl_alltoallv_coll : sycl_base_coll { using coll_base = sycl_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - using coll_base::comm; + using coll_base::host_send_buf; + using coll_base::host_recv_buf; - sycl_alltoallv_coll(bench_init_attr init_attr) - : coll_base(init_attr, - coll_base::comm().size(), - coll_base::comm().size(), - coll_base::comm().size()) {} + sycl_alltoallv_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { - size_t local_rank = 
coll_base::comm().rank(); - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count*coll_base::comm().size()}, [=](item<1> e_idx) - { - send_buf_acc[e_idx] = local_rank; - recv_buf_acc[e_idx] = 0; - }); - }); - } - } + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + Dtype sbuf_expected = comm.rank(); + int comm_size = comm.size(); - virtual void finalize(size_t elem_count) override { - bool unexpected_device_value = false; - Dtype sbuf_expected = coll_base::comm().rank(); - size_t comm_size = coll_base::comm().size(); + size_t send_bytes = comm_size * elem_count * base_coll::get_dtype_size(); + size_t recv_bytes = comm_size * elem_count * base_coll::get_dtype_size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count * comm_size}, [=](item<1> e_idx) mutable - { - Dtype value = send_buf_acc[e_idx]; - Dtype rbuf_expected = static_cast(e_idx.get_id(0) / elem_count); - if (value != sbuf_expected) - unexpected_device_value = true; + if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { + stream.get_native() + .memcpy(host_send_buf.data(), send_bufs[b_idx][rank_idx], send_bytes) + .wait(); - value = recv_buf_acc[e_idx]; - if (value != rbuf_expected) - unexpected_device_value = true; - }); - }); - } + stream.get_native() + 
.memcpy(host_recv_buf.data(), recv_bufs[b_idx][rank_idx], recv_bytes) + .wait(); + } + else { + auto send_buf = (static_cast*>(send_bufs[b_idx][rank_idx])); + auto recv_buf = (static_cast*>(recv_bufs[b_idx][rank_idx])); + auto send_buf_acc = send_buf->template get_access(); + auto recv_buf_acc = recv_buf->template get_access(); - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(); - auto recv_buf_acc = recv_buf->template get_access(); + stream.get_native() + .memcpy(host_send_buf.data(), send_buf_acc.get_pointer(), send_bytes) + .wait(); + + stream.get_native() + .memcpy(host_recv_buf.data(), recv_buf_acc.get_pointer(), recv_bytes) + .wait(); + } + + Dtype value; for (size_t e_idx = 0; e_idx < elem_count * comm_size; e_idx++) { - Dtype value = send_buf_acc[e_idx]; + value = host_send_buf[e_idx]; Dtype rbuf_expected = e_idx / elem_count; if (value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } - value = recv_buf_acc[e_idx]; + value = host_recv_buf[e_idx]; if (value != rbuf_expected) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << rbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << rbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } } - - if (unexpected_device_value) - ASSERT(0, "unexpected value on 
device"); } }; #endif /* CCL_ENABLE_SYCL */ diff --git a/examples/benchmark/src/bcast/bcast_strategy.hpp b/examples/benchmark/src/bcast/bcast_strategy.hpp index 0a1623b9f..369afd28d 100644 --- a/examples/benchmark/src/bcast/bcast_strategy.hpp +++ b/examples/benchmark/src/bcast/bcast_strategy.hpp @@ -20,12 +20,20 @@ struct bcast_strategy_impl { return "bcast"; } + size_t get_send_multiplier() { + return 1; + } + + size_t get_recv_multiplier() { + return 1; + } + static const ccl::broadcast_attr& get_op_attr(const bench_exec_attr& bench_attr) { return bench_attr.get_attr(); } - template - void start_internal(comm_t& comm, + template + void start_internal(ccl::communicator& comm, size_t count, Dtype send_buf, Dtype recv_buf, @@ -33,6 +41,7 @@ struct bcast_strategy_impl { req_list_t& reqs, Args&&... args) { (void)send_buf; - reqs.push_back(ccl::broadcast(recv_buf, count, COLL_ROOT, comm, std::forward(args)...)); + reqs.push_back( + ccl::broadcast(recv_buf, count, COLL_ROOT, comm, std::forward(args)...)); } }; diff --git a/examples/benchmark/src/bcast/cpu_bcast_coll.hpp b/examples/benchmark/src/bcast/cpu_bcast_coll.hpp index cf1066db2..d6b2d8945 100644 --- a/examples/benchmark/src/bcast/cpu_bcast_coll.hpp +++ b/examples/benchmark/src/bcast/cpu_bcast_coll.hpp @@ -22,31 +22,35 @@ template struct cpu_bcast_coll : cpu_base_coll { using coll_base = cpu_base_coll; using coll_base::recv_bufs; - using coll_base::single_recv_buf; - using coll_base::comm; - cpu_bcast_coll(bench_init_attr init_attr) - : coll_base(init_attr) {} + cpu_bcast_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { + virtual void prepare_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - if (coll_base::comm().rank() == COLL_ROOT) - ((Dtype*)recv_bufs[b_idx])[e_idx] = 
e_idx; + if (comm.rank() == COLL_ROOT) + ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx] = e_idx; else - ((Dtype*)recv_bufs[b_idx])[e_idx] = 0; + ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx] = 0; } } } - virtual void finalize(size_t elem_count) override { + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { Dtype value; for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - value = ((Dtype*)recv_bufs[b_idx])[e_idx]; + value = ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx]; if (static_cast(value) != e_idx) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << e_idx << ", got " << value << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " << e_idx + << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } diff --git a/examples/benchmark/src/bcast/sycl_bcast_coll.hpp b/examples/benchmark/src/bcast/sycl_bcast_coll.hpp index 428e6cb95..863f62a01 100644 --- a/examples/benchmark/src/bcast/sycl_bcast_coll.hpp +++ b/examples/benchmark/src/bcast/sycl_bcast_coll.hpp @@ -20,71 +20,85 @@ #ifdef CCL_ENABLE_SYCL #include "sycl_coll.hpp" -template -class bcast_buf_check {}; - -template -class bcast_buf_fill {}; - template struct sycl_bcast_coll : sycl_base_coll { using coll_base = sycl_base_coll; using coll_base::recv_bufs; - using coll_base::single_recv_buf; - using coll_base::comm; + using coll_base::host_recv_buf; - sycl_bcast_coll(bench_init_attr init_attr) - : coll_base(init_attr) {} + sycl_bcast_coll(bench_init_attr init_attr) : coll_base(init_attr) {} + + virtual void prepare_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + int comm_rank = comm.rank(); + + size_t count = elem_count; + size_t bytes = count * 
base_coll::get_dtype_size(); + + std::iota(host_recv_buf.begin(), host_recv_buf.end(), 0); - virtual void prepare(size_t elem_count) override { - size_t local_rank = coll_base::comm().rank(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count}, [=](item<1> e_idx) - { - if (local_rank == COLL_ROOT) - recv_buf_acc[e_idx] = e_idx.get_id(0); - else - recv_buf_acc[e_idx] = 0; - }); - }); + if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { + if (comm_rank == COLL_ROOT) + stream.get_native() + .memcpy(recv_bufs[b_idx][rank_idx], host_recv_buf.data(), bytes) + .wait(); + else + stream.get_native().memset(recv_bufs[b_idx][rank_idx], 0, bytes).wait(); + } + else { + stream.get_native() + .submit([&](handler& h) { + auto recv_buf = + (static_cast*>(recv_bufs[b_idx][rank_idx])); + auto recv_buf_acc = recv_buf->template get_access(h); + h.parallel_for(range<1>{ elem_count }, [=](item<1> e_idx) { + if (comm_rank == COLL_ROOT) + recv_buf_acc[e_idx] = e_idx.get_id(0); + else + recv_buf_acc[e_idx] = 0; + }); + }) + .wait(); + } } } - virtual void finalize(size_t elem_count) override { - bool unexpected_device_value = false; + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + size_t bytes = elem_count * base_coll::get_dtype_size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count}, [=](item<1> e_idx) mutable - { - if (recv_buf_acc[e_idx] != e_idx.get_id(0)) - unexpected_device_value = true; - }); - }); - } + if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) 
{ + stream.get_native() + .memcpy(host_recv_buf.data(), recv_bufs[b_idx][rank_idx], bytes) + .wait(); + } + else { + auto recv_buf = (static_cast*>(recv_bufs[b_idx][rank_idx])); + auto recv_buf_acc = recv_buf->template get_access(); - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto recv_buf_acc = recv_buf->template get_access(); + stream.get_native() + .memcpy(host_recv_buf.data(), recv_buf_acc.get_pointer(), bytes) + .wait(); + } + + Dtype value; for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - Dtype value = recv_buf_acc[e_idx]; + value = host_recv_buf[e_idx]; if (value != e_idx) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << (Dtype)e_idx << ", got " << value - << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " << (Dtype)e_idx + << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } } - - if (unexpected_device_value) - ASSERT(0, "unexpected value on device"); } }; #endif /* CCL_ENABLE_SYCL */ diff --git a/examples/benchmark/src/benchmark.cpp b/examples/benchmark/src/benchmark.cpp index b9a3597f6..314d7aa32 100644 --- a/examples/benchmark/src/benchmark.cpp +++ b/examples/benchmark/src/benchmark.cpp @@ -27,15 +27,13 @@ #include "benchmark.hpp" #include "declarations.hpp" - #include "transport_impl.hpp" -void do_regular(const ccl::communicator& comm, +void do_regular(ccl::communicator& service_comm, bench_exec_attr& bench_attr, coll_list_t& all_colls, req_list_t& reqs, const user_options_t& options) { - std::stringstream match_id_stream; for (auto dtype : all_dtypes) { @@ -59,187 +57,121 @@ void do_regular(const ccl::communicator& comm, if (!find_key_val(reduction_op, reduction_names, reduction)) continue; - PRINT_BY_ROOT( - comm, "\ndtype: %s\nreduction: %s\n", dtype_name.c_str(), reduction.c_str()); 
+ PRINT_BY_ROOT(service_comm, + "\ndtype: %s\nreduction: %s\n", + dtype_name.c_str(), + reduction.c_str()); reqs.reserve(colls.size() * options.buf_count); - /* warm up */ - PRINT_BY_ROOT(comm, "do warm up"); - bench_attr.reduction = reduction_op; bench_attr.set(true); - ccl::barrier(comm); - - for (size_t count = options.min_elem_count; count <= options.max_elem_count; - count *= 2) { - for (size_t iter_idx = 0; iter_idx < options.warmup_iters; iter_idx++) { - for (size_t coll_idx = 0; coll_idx < colls.size(); coll_idx++) { - auto& coll = colls[coll_idx]; - for (size_t buf_idx = 0; buf_idx < options.buf_count; buf_idx++) { - match_id_stream << "coll_" << coll->name() - << "_" << coll_idx << "_count_" << count - << "_buf_" << buf_idx; - bench_attr.set(match_id_stream.str()); - match_id_stream.str(""); - coll->start(count, buf_idx, bench_attr, reqs); - } - } - for (auto& req : reqs) { - req.wait(); - } - reqs.clear(); - } - } - std::ostringstream scolls; std::copy(options.coll_names.begin(), options.coll_names.end(), std::ostream_iterator{ scolls, " " }); - ccl::barrier(comm); + ccl::barrier(service_comm); /* benchmark with multiple equal sized buffer per collective */ - if (options.buf_type == BUF_MULTI) { - PRINT_BY_ROOT(comm, - "do multi-buffers benchmark\n" - "#------------------------------------------------------------\n" - "# Benchmarking: %s\n" - "# ranks: %zu\n" - "#------------------------------------------------------------\n" + PRINT_BY_ROOT(service_comm, + "#------------------------------------------------------------\n" + "# Benchmarking: %s\n" + "# processes: %d\n" + "#------------------------------------------------------------\n", + scolls.str().c_str(), + service_comm.size()); + + if (options.buf_count == 1) { + PRINT_BY_ROOT(service_comm, "%10s %12s %11s", "#bytes", "avg[usec]", "stddev[%]"); + } + else { + PRINT_BY_ROOT(service_comm, "%10s %13s %18s %11s", - scolls.str().c_str(), - comm.size(), "#bytes", "avg[usec]", "avg_per_buf[usec]", 
"stddev[%]"); - bench_attr.set(true); - for (size_t count = options.min_elem_count; count <= options.max_elem_count; - count *= 2) { - try { - // we store times for each collective separately, - // but aggregate over buffers and iterations - std::vector coll_timers(colls.size(), 0); - for (size_t coll_idx = 0; coll_idx < colls.size(); coll_idx++) { - ccl::barrier(comm); - - double t1 = 0, t2 = 0, t = 0; - - for (size_t iter_idx = 0; iter_idx < options.iters; iter_idx++) { - auto& coll = colls[coll_idx]; - // collective is configured to handle only - // options.buf_count many buffers/executions 'at once'. - // -> check cannot combine executions over iterations - // -> wait and check and must be in this loop nest - if (options.check_values) { - coll->prepare(count); - } - - ccl::barrier(comm); - - t1 = when(); - - for (size_t buf_idx = 0; buf_idx < options.buf_count; buf_idx++) { - match_id_stream << "coll_" << coll->name() - << "_" << coll_idx << "_count_" << count - << "_buf_" << buf_idx; - bench_attr.set(match_id_stream.str()); - match_id_stream.str(""); - coll->start(count, buf_idx, bench_attr, reqs); - } - - for (auto& req : reqs) { - req.wait(); - } - reqs.clear(); - - t2 = when(); - t += (t2 - t1); - - if (options.check_values) { - coll->finalize(count); - } - } - coll_timers[coll_idx] += t; - } - print_timings(comm, coll_timers, options, count, dtype, reduction_op); - } - catch (const std::exception& ex) { - ASSERT(0, "error on count %zu, reason: %s", count, ex.what()); - } - } } - else { - /* benchmark with single buffer per collective */ - PRINT_BY_ROOT(comm, - "do single-buffer benchmark\n" - "#--------------------------------------\n" - "# Benchmarking: %s\n" - "# ranks: %zu\n" - "#--------------------------------------\n" - "%10s %12s %11s", - scolls.str().c_str(), - comm.size(), - "#bytes", - "avg[usec]", - "stddev[%]"); - size_t min_elem_count = options.min_elem_count * options.buf_count; - size_t max_elem_count = options.max_elem_count * 
options.buf_count; - bench_attr.set(true); - for (size_t count = min_elem_count; count <= max_elem_count; count *= 2) { - try { - // we store times for each collective separately, - // but aggregate over iterations - std::vector coll_timers(colls.size(), 0); + for (size_t count = options.min_elem_count; count <= options.max_elem_count; + count *= 2) { + size_t iter_count = + get_iter_count(count * ccl::get_datatype_size(dtype), options.iters); - double t1 = 0, t2 = 0; + size_t warmup_iter_count = + get_iter_count(count * ccl::get_datatype_size(dtype), options.warmup_iters); - for (size_t coll_idx = 0; coll_idx < colls.size(); coll_idx++) { - auto& coll = colls[coll_idx]; + try { + // we store times for each collective separately, + // but aggregate over buffers and iterations + std::vector coll_timers(colls.size(), 0); + for (size_t coll_idx = 0; coll_idx < colls.size(); coll_idx++) { + auto& coll = colls[coll_idx]; + + ccl::barrier(service_comm); + + double t1 = 0, t2 = 0, t = 0; + + for (size_t iter_idx = 0; iter_idx < (iter_count + warmup_iter_count); + iter_idx++) { + // collective is configured to handle only + // options.buf_count many buffers/executions 'at once'. 
+ // -> check cannot combine executions over iterations + // -> wait and check and must be in this loop nest + if (options.check_values) { + coll->prepare(count); + } - ccl::barrier(comm); + ccl::barrier(service_comm); t1 = when(); - for (size_t iter_idx = 0; iter_idx < options.iters; iter_idx++) { - match_id_stream << "coll_" << coll->name() - << "_" << coll_idx << "_single_count_" << count; - bench_attr.set(match_id_stream.str()); + for (size_t buf_idx = 0; buf_idx < options.buf_count; buf_idx++) { + match_id_stream << "coll_" << coll->name() << "_" << coll_idx + << "_count_" << count << "_buf_" << buf_idx; + bench_attr.set( + ccl::string_class(match_id_stream.str())); match_id_stream.str(""); - coll->start_single(count, bench_attr, reqs); - for (auto& req : reqs) { - req.wait(); - } - reqs.clear(); + coll->start(count, buf_idx, bench_attr, reqs); } + for (auto& req : reqs) { + req.wait(); + } + reqs.clear(); + t2 = when(); - coll_timers[coll_idx] += (t2 - t1); - } + if (iter_idx >= warmup_iter_count) { + t += (t2 - t1); + } - print_timings(comm, coll_timers, options, count, dtype, reduction_op); - } - catch (...) 
{ - ASSERT(0, "error on count %zu", count); + if (options.check_values) { + coll->finalize(count); + } + } + coll_timers[coll_idx] += t; } + + print_timings( + service_comm, coll_timers, options, count, iter_count, dtype, reduction_op); + } + catch (const std::exception& ex) { + ASSERT(0, "error on count %zu, reason: %s", count, ex.what()); } - PRINT_BY_ROOT(comm, "PASSED\n"); } } } } -void do_unordered(const ccl::communicator& comm, +void do_unordered(ccl::communicator& service_comm, bench_exec_attr& bench_attr, coll_list_t& all_colls, req_list_t& reqs, const user_options_t& options) { - - std::set match_ids; + std::set match_ids; std::stringstream match_id_stream; for (auto dtype : all_dtypes) { @@ -263,14 +195,16 @@ void do_unordered(const ccl::communicator& comm, if (!find_key_val(reduction_op, reduction_names, reduction)) continue; - PRINT_BY_ROOT( - comm, "\ndtype: %s\nreduction: %s\n", dtype_name.c_str(), reduction.c_str()); + PRINT_BY_ROOT(service_comm, + "\ndtype: %s\nreduction: %s\n", + dtype_name.c_str(), + reduction.c_str()); - size_t rank = comm.rank(); + int rank = service_comm.rank(); reqs.reserve(colls.size() * options.buf_count * (log2(options.max_elem_count) + 1)); - PRINT_BY_ROOT(comm, "do unordered test"); + PRINT_BY_ROOT(service_comm, "do unordered test"); bench_attr.reduction = reduction_op; bench_attr.set(true); @@ -281,14 +215,13 @@ void do_unordered(const ccl::communicator& comm, for (size_t coll_idx = 0; coll_idx < colls.size(); coll_idx++) { auto& coll = colls[coll_idx]; for (size_t buf_idx = 0; buf_idx < options.buf_count; buf_idx++) { - match_id_stream << "coll_" << coll->name() - << "_" << coll_idx << "_count_" << count - << "_buf_" << buf_idx; - bench_attr.set(match_id_stream.str()); + match_id_stream << "coll_" << coll->name() << "_" << coll_idx + << "_count_" << count << "_buf_" << buf_idx; + bench_attr.set( + ccl::string_class(match_id_stream.str())); match_ids.insert(match_id_stream.str()); match_id_stream.str(""); 
coll->start(count, buf_idx, bench_attr, reqs); - } } } @@ -298,10 +231,10 @@ void do_unordered(const ccl::communicator& comm, auto& coll = colls[real_coll_idx]; for (size_t buf_idx = 0; buf_idx < options.buf_count; buf_idx++) { size_t real_buf_idx = options.buf_count - buf_idx - 1; - match_id_stream << "coll_" << coll->name() - << "_" << real_coll_idx << "_count_" << count - << "_buf_" << real_buf_idx; - bench_attr.set(match_id_stream.str()); + match_id_stream << "coll_" << coll->name() << "_" << real_coll_idx + << "_count_" << count << "_buf_" << real_buf_idx; + bench_attr.set( + ccl::string_class(match_id_stream.str())); match_ids.insert(match_id_stream.str()); match_id_stream.str(""); coll->start(count, real_buf_idx, bench_attr, reqs); @@ -327,20 +260,18 @@ void do_unordered(const ccl::communicator& comm, catch (...) { ASSERT(0, "error on coll completion"); } - PRINT_BY_ROOT(comm, "PASSED\n"); + PRINT_BY_ROOT(service_comm, "PASSED\n"); } } } template -void create_cpu_colls(bench_init_attr& init_attr, - user_options_t& options, - coll_list_t& colls) { - using namespace sparse_detail; - using incremental_index_int_sparse_strategy = - sparse_allreduce_strategy_impl; - using incremental_index_bf16_sparse_strategy = - sparse_allreduce_strategy_impl; +void create_cpu_colls(bench_init_attr& init_attr, user_options_t& options, coll_list_t& colls) { + // using namespace sparse_detail; + // using incremental_index_int_sparse_strategy = + // sparse_allreduce_strategy_impl; + // using incremental_index_bf16_sparse_strategy = + // sparse_allreduce_strategy_impl; std::stringstream error_messages_stream; @@ -367,35 +298,35 @@ void create_cpu_colls(bench_init_attr& init_attr, else if (name == reduce_scatter_strategy_impl::class_name()) { colls.emplace_back(new cpu_reduce_scatter_coll(init_attr)); } - else if (name.find(incremental_index_int_sparse_strategy::class_name()) != - std::string::npos) { - if (name.find(incremental_index_bf16_sparse_strategy::class_name()) != - 
std::string::npos) { - if (is_bf16_enabled() == 0) { - error_messages_stream << "bfloat16 is not supported for current CPU, skipping " - << name << ".\n"; - names_it = options.coll_names.erase(names_it); - continue; - } -#ifdef CCL_bf16_COMPILER - colls.emplace_back( - new cpu_sparse_allreduce_coll( - init_attr, - sizeof(float) / sizeof(ccl::bf16), - sizeof(float) / sizeof(ccl::bf16))); -#else - error_messages_stream << "bfloat16 is not supported by current compiler, skipping " - << name << ".\n"; - names_it = options.coll_names.erase(names_it); - continue; -#endif - } - else { - colls.emplace_back(new cpu_sparse_allreduce_coll(init_attr)); - } - } + // else if (name.find(incremental_index_int_sparse_strategy::class_name()) != + // std::string::npos) { + // if (name.find(incremental_index_bf16_sparse_strategy::class_name()) != + // std::string::npos) { + // if (is_bf16_enabled() == 0) { + // error_messages_stream << "bfloat16 is not supported for current CPU, skipping " + // << name << ".\n"; + // names_it = options.coll_names.erase(names_it); + // continue; + // } + // #ifdef CCL_bf16_COMPILER + // colls.emplace_back( + // new cpu_sparse_allreduce_coll( + // init_attr, + // sizeof(float) / sizeof(ccl::bfloat16), + // sizeof(float) / sizeof(ccl::bfloat16))); + // #else + // error_messages_stream << "bfloat16 is not supported by current compiler, skipping " + // << name << ".\n"; + // names_it = options.coll_names.erase(names_it); + // continue; + // #endif + // } + // else { + // colls.emplace_back(new cpu_sparse_allreduce_coll(init_attr)); + // } + // } else { ASSERT(0, "create_colls error, unknown coll name: %s", name.c_str()); } @@ -415,13 +346,11 @@ void create_cpu_colls(bench_init_attr& init_attr, #ifdef CCL_ENABLE_SYCL template -void create_sycl_colls(bench_init_attr& init_attr, - user_options_t& options, - coll_list_t& colls) { - using incremental_index_int_sparse_strategy = - sparse_allreduce_strategy_impl; - using incremental_index_bf16_sparse_strategy = - 
sparse_allreduce_strategy_impl; +void create_sycl_colls(bench_init_attr& init_attr, user_options_t& options, coll_list_t& colls) { + // using incremental_index_int_sparse_strategy = + // sparse_allreduce_strategy_impl; + // using incremental_index_bf16_sparse_strategy = + // sparse_allreduce_strategy_impl; std::stringstream error_messages_stream; @@ -449,48 +378,48 @@ void create_sycl_colls(bench_init_attr& init_attr, else if (name == reduce_scatter_strategy_impl::class_name()) { colls.emplace_back(new sycl_reduce_scatter_coll(init_attr)); } - else if (name.find(incremental_index_int_sparse_strategy::class_name()) != - std::string::npos) { - // TODO case is not supported yet - if (true) { - error_messages_stream << "SYCL coll: skipping " << name - << ", because it is not supported yet.\n"; - names_it = options.coll_names.erase(names_it); - continue; - } - colls.emplace_back(new sycl_sparse_allreduce_coll(init_attr)); - } - else if (name.find(incremental_index_bf16_sparse_strategy::class_name()) != - std::string::npos) { - // TODO case is not supported yet - if (true) { - error_messages_stream << "SYCL coll: skipping " << name - << ", because it is not supported yet.\n"; - names_it = options.coll_names.erase(names_it); - continue; - } - - if (is_bf16_enabled() == 0) { - error_messages_stream << "SYCL bf16 is not supported for current CPU, skipping " - << name << ".\n"; - names_it = options.coll_names.erase(names_it); - continue; - } -#ifdef CCL_bf16_COMPILER - colls.emplace_back( - new sycl_sparse_allreduce_coll( - init_attr, - sizeof(float) / sizeof(ccl::bf16), - sizeof(float) / sizeof(ccl::bf16))); -#else - error_messages_stream << "SYCL bf16 is not supported by current compiler, skipping " - << name << ".\n"; - names_it = options.coll_names.erase(names_it); - continue; -#endif - } + // else if (name.find(incremental_index_int_sparse_strategy::class_name()) != + // std::string::npos) { + // // TODO case is not supported yet + // if (true) { + // 
error_messages_stream << "SYCL coll: skipping " << name + // << ", because it is not supported yet.\n"; + // names_it = options.coll_names.erase(names_it); + // continue; + // } + // colls.emplace_back(new sycl_sparse_allreduce_coll(init_attr)); + // } + // else if (name.find(incremental_index_bf16_sparse_strategy::class_name()) != + // std::string::npos) { + // // TODO case is not supported yet + // if (true) { + // error_messages_stream << "SYCL coll: skipping " << name + // << ", because it is not supported yet.\n"; + // names_it = options.coll_names.erase(names_it); + // continue; + // } + + // if (is_bf16_enabled() == 0) { + // error_messages_stream << "SYCL bf16 is not supported for current CPU, skipping " + // << name << ".\n"; + // names_it = options.coll_names.erase(names_it); + // continue; + // } + // #ifdef CCL_bf16_COMPILER + // colls.emplace_back( + // new sycl_sparse_allreduce_coll( + // init_attr, + // sizeof(float) / sizeof(ccl::bfloat16), + // sizeof(float) / sizeof(ccl::bfloat16))); + // #else + // error_messages_stream << "SYCL bf16 is not supported by current compiler, skipping " + // << name << ".\n"; + // names_it = options.coll_names.erase(names_it); + // continue; + // #endif + // } else { ASSERT(0, "create_colls error, unknown coll name: %s", name.c_str()); } @@ -513,9 +442,7 @@ void create_sycl_colls(bench_init_attr& init_attr, template void create_colls(bench_init_attr& init_attr, user_options_t& options, coll_list_t& colls) { switch (options.backend) { - case BACKEND_HOST: - create_cpu_colls(init_attr, options, colls); - break; + case BACKEND_HOST: create_cpu_colls(init_attr, options, colls); break; case BACKEND_SYCL: #ifdef CCL_ENABLE_SYCL create_sycl_colls(init_attr, options, colls); @@ -527,28 +454,24 @@ void create_colls(bench_init_attr& init_attr, user_options_t& options, coll_list } } -/* Reason to leave a functor here: In order to call a function (create_colls()) - * with all dtypes (from ccl::datatype) the functor requires the 
implementation - * of that function. */ -class create_colls_func { -private: - bench_init_attr& init_attr; - user_options_t& options; - coll_list_t& colls; - -public: - create_colls_func(bench_init_attr& init_attr, user_options_t& options, coll_list_t& colls) - : init_attr(init_attr), - options(options), - colls(colls) {} - - template - void operator()(const Dtype& value) { - if (true == std::get<0>(value)) { - create_colls(init_attr, options, colls); - } +void create_all_colls(bench_init_attr& init_attr, user_options_t& options, coll_list_t& colls) { + for (auto& dtype : options.dtypes) { + if (dtype == dtype_names[ccl::datatype::int8]) + create_colls(init_attr, options, colls); + else if (dtype == dtype_names[ccl::datatype::int32]) + create_colls(init_attr, options, colls); + else if (dtype == dtype_names[ccl::datatype::int64]) + create_colls(init_attr, options, colls); + else if (dtype == dtype_names[ccl::datatype::uint64]) + create_colls(init_attr, options, colls); + else if (dtype == dtype_names[ccl::datatype::float32]) + create_colls(init_attr, options, colls); + else if (dtype == dtype_names[ccl::datatype::float64]) + create_colls(init_attr, options, colls); + else + ASSERT(0, "unexpected datatype %s", dtype.c_str()); } -}; +} int main(int argc, char* argv[]) { user_options_t options; @@ -558,34 +481,24 @@ int main(int argc, char* argv[]) { bench_init_attr init_attr; if (parse_user_options(argc, argv, options)) { - PRINT("failed to parse user options"); print_help_usage(argv[0]); + return -1; } + auto& transport = transport_data::instance(); + transport.init_comms(options); + + ccl::communicator& service_comm = transport.get_service_comm(); + init_attr.buf_count = options.buf_count; init_attr.max_elem_count = options.max_elem_count; + init_attr.ranks_per_proc = options.ranks_per_proc; + init_attr.sycl_mem_type = options.sycl_mem_type; + init_attr.sycl_usm_type = options.sycl_usm_type; init_attr.v2i_ratio = options.v2i_ratio; - 
host_data::init(transport_settings::instance().get_size(), - transport_settings::instance().get_rank(), - transport_settings::instance().get_kvs()); -#ifdef CCL_ENABLE_SYCL - if (options.backend == BACKEND_SYCL) { - - auto dev = get_device(*host_data::comm_ptr); - cl::sycl::context ctx(dev); - - device_data::init(transport_settings::instance().get_size(), - transport_settings::instance().get_rank(), - dev, - ctx, - transport_settings::instance().get_kvs()); - } -#endif - try { - ccl_tuple_for_each(launch_dtypes, set_dtypes_func(options.dtypes)); - ccl_tuple_for_each(launch_dtypes, create_colls_func(init_attr, options, colls)); + create_all_colls(init_attr, options, colls); } catch (const std::runtime_error& e) { ASSERT(0, "cannot create coll objects: %s\n", e.what()); @@ -595,25 +508,23 @@ int main(int argc, char* argv[]) { return -1; } - ccl::communicator& comm = *host_data::comm_ptr; - bench_exec_attr bench_attr{}; bench_attr.init_all(); - print_user_options(options, comm); + print_user_options(options, service_comm); if (options.coll_names.empty()) { - PRINT_BY_ROOT(comm, "empty coll list"); + PRINT_BY_ROOT(service_comm, "empty coll list"); print_help_usage(argv[0]); return -1; } - ccl::barrier(comm); + ccl::barrier(service_comm); switch (options.loop) { case LOOP_REGULAR: { // open and truncate CSV file if csv-output is requested - if (comm.rank() == 0 && !options.csv_filepath.empty()) { + if (service_comm.rank() == 0 && !options.csv_filepath.empty()) { std::ofstream csvf; csvf.open(options.csv_filepath, std::ios::trunc); if (!csvf.is_open()) { @@ -626,22 +537,18 @@ int main(int argc, char* argv[]) { << std::endl; csvf.close(); } - ccl::barrier(comm); - do_regular(comm, bench_attr, colls, reqs, options); + ccl::barrier(service_comm); + do_regular(service_comm, bench_attr, colls, reqs, options); break; } case LOOP_UNORDERED: { // no timing is printed or exported here - ccl::barrier(comm); - do_unordered(comm, bench_attr, colls, reqs, options); + 
ccl::barrier(service_comm); + do_unordered(service_comm, bench_attr, colls, reqs, options); break; } default: ASSERT(0, "unknown loop %d", options.loop); break; } -#ifdef CCL_ENABLE_SYCL - device_data::deinit(); -#endif - host_data::deinit(); return 0; } diff --git a/examples/benchmark/src/declarations.hpp b/examples/benchmark/src/declarations.hpp index 72263566f..7bcd15950 100755 --- a/examples/benchmark/src/declarations.hpp +++ b/examples/benchmark/src/declarations.hpp @@ -51,7 +51,7 @@ #include "reduce_scatter/sycl_reduce_scatter_coll.hpp" /* sparse_allreduce implementation */ -#include "sparse_allreduce/sparse_allreduce_base.hpp" -#include "sparse_allreduce/sparse_allreduce_strategy.hpp" -#include "sparse_allreduce/cpu_sparse_allreduce_coll.hpp" -#include "sparse_allreduce/sycl_sparse_allreduce_coll.hpp" +// #include "sparse_allreduce/sparse_allreduce_base.hpp" +// #include "sparse_allreduce/sparse_allreduce_strategy.hpp" +// #include "sparse_allreduce/cpu_sparse_allreduce_coll.hpp" +// #include "sparse_allreduce/sycl_sparse_allreduce_coll.hpp" diff --git a/examples/benchmark/src/reduce/cpu_reduce_coll.hpp b/examples/benchmark/src/reduce/cpu_reduce_coll.hpp index e049fde6b..0a0c9d445 100644 --- a/examples/benchmark/src/reduce/cpu_reduce_coll.hpp +++ b/examples/benchmark/src/reduce/cpu_reduce_coll.hpp @@ -23,43 +23,34 @@ struct cpu_reduce_coll : cpu_base_coll { using coll_base = cpu_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - cpu_reduce_coll(bench_init_attr init_attr) - : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - ((Dtype*)send_bufs[b_idx])[e_idx] = coll_base::comm().rank(); - ((Dtype*)recv_bufs[b_idx])[e_idx] = 0; - } - } - } + cpu_reduce_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual 
void finalize(size_t elem_count) override { - Dtype sbuf_expected = coll_base::comm().rank(); - Dtype rbuf_expected = - (coll_base::comm().size() - 1) * ((float)coll_base::comm().size() / 2); + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + Dtype sbuf_expected = comm.rank(); + Dtype rbuf_expected = (comm.size() - 1) * ((float)comm.size() / 2); Dtype value; for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - value = ((Dtype*)send_bufs[b_idx])[e_idx]; + value = ((Dtype*)send_bufs[b_idx][rank_idx])[e_idx]; if (value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } - if (coll_base::comm().rank() != COLL_ROOT) + if (comm.rank() != COLL_ROOT) continue; - value = ((Dtype*)recv_bufs[b_idx])[e_idx]; + value = ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx]; if (value != rbuf_expected) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << rbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << rbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } diff --git a/examples/benchmark/src/reduce/reduce_strategy.hpp b/examples/benchmark/src/reduce/reduce_strategy.hpp index 5efa01b70..fb3f5efeb 100644 --- a/examples/benchmark/src/reduce/reduce_strategy.hpp +++ b/examples/benchmark/src/reduce/reduce_strategy.hpp @@ -23,12 +23,20 @@ struct reduce_strategy_impl { 
return "reduce"; } + size_t get_send_multiplier() { + return 1; + } + + size_t get_recv_multiplier() { + return 1; + } + static const ccl::reduce_attr& get_op_attr(const bench_exec_attr& bench_attr) { return bench_attr.get_attr(); } - template - void start_internal(comm_t& comm, + template + void start_internal(ccl::communicator& comm, size_t count, const Dtype send_buf, Dtype recv_buf, diff --git a/examples/benchmark/src/reduce/sycl_reduce_coll.hpp b/examples/benchmark/src/reduce/sycl_reduce_coll.hpp index c714cedc4..4b87ce244 100644 --- a/examples/benchmark/src/reduce/sycl_reduce_coll.hpp +++ b/examples/benchmark/src/reduce/sycl_reduce_coll.hpp @@ -20,99 +20,76 @@ #ifdef CCL_ENABLE_SYCL #include "sycl_coll.hpp" -template -class reduce_buf_check {}; - -template -class reduce_buf_fill {}; - template struct sycl_reduce_coll : sycl_base_coll { using coll_base = sycl_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - using coll_base::comm; + using coll_base::host_send_buf; + using coll_base::host_recv_buf; - sycl_reduce_coll(bench_init_attr init_attr) - : coll_base(init_attr) {} + sycl_reduce_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { - size_t local_rank = coll_base::comm().rank(); - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count}, [=](item<1> e_idx) - { - send_buf_acc[e_idx] = local_rank; - recv_buf_acc[e_idx] = 0; - }); - }); - } - } + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + Dtype 
sbuf_expected = comm.rank(); + Dtype rbuf_expected = (comm.size() - 1) * ((float)comm.size() / 2); + + int comm_rank = comm.rank(); - virtual void finalize(size_t elem_count) override { - bool unexpected_device_value = false; - Dtype sbuf_expected = coll_base::comm().rank(); - Dtype rbuf_expected = - (coll_base::comm().size() - 1) * ((float)coll_base::comm().size() / 2); - size_t local_rank = coll_base::comm().rank(); + size_t send_bytes = elem_count * base_coll::get_dtype_size(); + size_t recv_bytes = elem_count * base_coll::get_dtype_size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count}, [=](item<1> e_idx) mutable - { - Dtype value = send_buf_acc[e_idx]; - if (value != sbuf_expected) - unexpected_device_value = true; + if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { + stream.get_native() + .memcpy(host_send_buf.data(), send_bufs[b_idx][rank_idx], send_bytes) + .wait(); - if (local_rank == COLL_ROOT) { - value = recv_buf_acc[e_idx]; - if (value != rbuf_expected) - unexpected_device_value = true; - } - }); - }); - } + stream.get_native() + .memcpy(host_recv_buf.data(), recv_bufs[b_idx][rank_idx], recv_bytes) + .wait(); + } + else { + auto send_buf = (static_cast*>(send_bufs[b_idx][rank_idx])); + auto recv_buf = (static_cast*>(recv_bufs[b_idx][rank_idx])); + auto send_buf_acc = send_buf->template get_access(); + auto recv_buf_acc = recv_buf->template get_access(); - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(); - auto recv_buf_acc = 
recv_buf->template get_access(); + stream.get_native() + .memcpy(host_send_buf.data(), send_buf_acc.get_pointer(), send_bytes) + .wait(); + + stream.get_native() + .memcpy(host_recv_buf.data(), recv_buf_acc.get_pointer(), recv_bytes) + .wait(); + } + + Dtype value; for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - Dtype value = send_buf_acc[e_idx]; + value = host_send_buf[e_idx]; if (value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } - if (local_rank != COLL_ROOT) + if (comm_rank != COLL_ROOT) continue; - value = recv_buf_acc[e_idx]; + value = host_recv_buf[e_idx]; if (value != rbuf_expected) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << rbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << rbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } } - - if (unexpected_device_value) - ASSERT(0, "unexpected value on device"); } }; #endif /* CCL_ENABLE_SYCL */ diff --git a/examples/benchmark/src/reduce_scatter/cpu_reduce_scatter_coll.hpp b/examples/benchmark/src/reduce_scatter/cpu_reduce_scatter_coll.hpp index 93138ce12..f9bf0107a 100644 --- a/examples/benchmark/src/reduce_scatter/cpu_reduce_scatter_coll.hpp +++ b/examples/benchmark/src/reduce_scatter/cpu_reduce_scatter_coll.hpp @@ -23,49 +23,36 @@ struct cpu_reduce_scatter_coll : cpu_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - 
cpu_reduce_scatter_coll(bench_init_attr init_attr) - : coll_base(init_attr) {} + cpu_reduce_scatter_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - ((Dtype*)send_bufs[b_idx])[e_idx] = coll_base::comm().rank(); - ((Dtype*)recv_bufs[b_idx])[e_idx] = 0; - } - } - } - - virtual void finalize(size_t elem_count) override { - Dtype sbuf_expected = coll_base::comm().rank(); - Dtype rbuf_expected = - (coll_base::comm().size() - 1) * ((float)coll_base::comm().size() / 2); + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + Dtype sbuf_expected = comm.rank(); + Dtype rbuf_expected = (comm.size() - 1) * ((float)comm.size() / 2); Dtype value; - size_t recv_elem_count = elem_count / coll_base::comm().size(); + size_t recv_elem_count = elem_count / comm.size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - - value = ((Dtype*)send_bufs[b_idx])[e_idx]; + value = ((Dtype*)send_bufs[b_idx][rank_idx])[e_idx]; if (value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } for (size_t e_idx = 0; e_idx < recv_elem_count; e_idx++) { - - value = ((Dtype*)recv_bufs[b_idx])[e_idx]; + value = ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx]; if (value != rbuf_expected) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << rbuf_expected << ", 
got " << value - << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << rbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } diff --git a/examples/benchmark/src/reduce_scatter/reduce_scatter_strategy.hpp b/examples/benchmark/src/reduce_scatter/reduce_scatter_strategy.hpp index 01a5c8adb..f71eb2d71 100644 --- a/examples/benchmark/src/reduce_scatter/reduce_scatter_strategy.hpp +++ b/examples/benchmark/src/reduce_scatter/reduce_scatter_strategy.hpp @@ -23,19 +23,26 @@ struct reduce_scatter_strategy_impl { return "reduce_scatter"; } + size_t get_send_multiplier() { + return 1; + } + + size_t get_recv_multiplier() { + return 1; + } + static const ccl::reduce_scatter_attr& get_op_attr(const bench_exec_attr& bench_attr) { return bench_attr.get_attr(); } - template - void start_internal(comm_t& comm, + template + void start_internal(ccl::communicator& comm, size_t send_count, const Dtype send_buf, Dtype recv_buf, const bench_exec_attr& bench_attr, req_list_t& reqs, Args&&... 
args) { - size_t recv_count = send_count / comm.size(); if (recv_count == 0) { diff --git a/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp b/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp index 9fa8701c9..186b02a15 100644 --- a/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp +++ b/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp @@ -20,105 +20,75 @@ #ifdef CCL_ENABLE_SYCL #include "sycl_coll.hpp" -template -class reduce_scatter_sbuf_check {}; - -template -class reduce_scatter_rbuf_check {}; - -template -class reduce_scatter_buf_fill {}; - template struct sycl_reduce_scatter_coll : sycl_base_coll { using coll_base = sycl_base_coll; using coll_base::send_bufs; using coll_base::recv_bufs; - using coll_base::single_send_buf; - using coll_base::single_recv_buf; - using coll_base::comm; + using coll_base::host_send_buf; + using coll_base::host_recv_buf; - sycl_reduce_scatter_coll(bench_init_attr init_attr) - : coll_base(init_attr) {} + sycl_reduce_scatter_coll(bench_init_attr init_attr) : coll_base(init_attr) {} - virtual void prepare(size_t elem_count) override { - size_t local_rank = coll_base::comm().rank(); - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count}, [=](item<1> e_idx) - { - send_buf_acc[e_idx] = local_rank; - recv_buf_acc[e_idx] = 0; - }); - }); - } - } + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + Dtype sbuf_expected = comm.rank(); + Dtype rbuf_expected = (comm.size() - 1) * ((float)comm.size() / 2); - virtual void finalize(size_t elem_count) override { - bool 
unexpected_device_value = false; - Dtype sbuf_expected = coll_base::comm().rank(); - Dtype rbuf_expected = - (coll_base::comm().size() - 1) * ((float)coll_base::comm().size() / 2); + size_t recv_elem_count = elem_count / comm.size(); - size_t recv_elem_count = elem_count / coll_base::comm().size(); + size_t send_bytes = elem_count * base_coll::get_dtype_size(); + size_t recv_bytes = elem_count * base_coll::get_dtype_size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{elem_count}, [=](item<1> e_idx) mutable - { - Dtype value = send_buf_acc[e_idx]; - if (value != sbuf_expected) - unexpected_device_value = true; - }); - }); - - device_data::sycl_queue.submit([&](handler& cgh) { - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto recv_buf_acc = recv_buf->template get_access(cgh); - cgh.parallel_for>(range<1>{recv_elem_count}, [=](item<1> e_idx) mutable - { - Dtype value = recv_buf_acc[e_idx]; - if (value != rbuf_expected) - unexpected_device_value = true; - }); - }); - } + if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { + stream.get_native() + .memcpy(host_send_buf.data(), send_bufs[b_idx][rank_idx], send_bytes) + .wait(); + + stream.get_native() + .memcpy(host_recv_buf.data(), recv_bufs[b_idx][rank_idx], recv_bytes) + .wait(); + } + else { + auto send_buf = (static_cast*>(send_bufs[b_idx][rank_idx])); + auto recv_buf = (static_cast*>(recv_bufs[b_idx][rank_idx])); + auto send_buf_acc = send_buf->template get_access(); + auto recv_buf_acc = recv_buf->template get_access(); + + stream.get_native() + .memcpy(host_send_buf.data(), send_buf_acc.get_pointer(), send_bytes) + .wait(); + + stream.get_native() + .memcpy(host_recv_buf.data(), recv_buf_acc.get_pointer(), recv_bytes) + .wait(); + } - for (size_t b_idx = 0; b_idx < 
base_coll::get_buf_count(); b_idx++) { - auto send_buf = (static_cast*>(send_bufs[b_idx])); - auto recv_buf = (static_cast*>(recv_bufs[b_idx])); - auto send_buf_acc = send_buf->template get_access(); - auto recv_buf_acc = recv_buf->template get_access(); + Dtype value; for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - Dtype value = send_buf_acc[e_idx]; + value = host_send_buf[e_idx]; if (value != sbuf_expected) { - std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << sbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << sbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } for (size_t e_idx = 0; e_idx < recv_elem_count; e_idx++) { - Dtype value = recv_buf_acc[e_idx]; + Dtype value = host_recv_buf[e_idx]; if (value != rbuf_expected) { - std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", elem_idx " - << e_idx << ", expected " << rbuf_expected << ", got " << value - << std::endl; + std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " + << rank_idx << ", elem_idx " << e_idx << ", expected " + << rbuf_expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } } } - - if (unexpected_device_value) - ASSERT(0, "unexpected value on device"); } }; #endif /* CCL_ENABLE_SYCL */ diff --git a/examples/benchmark/src/sparse_allreduce/cpu_sparse_allreduce_coll.hpp b/examples/benchmark/src/sparse_allreduce/cpu_sparse_allreduce_coll.hpp index 0aef29bba..1445a852a 100644 --- a/examples/benchmark/src/sparse_allreduce/cpu_sparse_allreduce_coll.hpp +++ b/examples/benchmark/src/sparse_allreduce/cpu_sparse_allreduce_coll.hpp @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ +#if 0 + #pragma once template class IndicesDistributorType = sparse_detail::incremental_indices_distributor> struct cpu_sparse_allreduce_coll - : base_sparse_allreduce_coll, - host_data { + : base_sparse_allreduce_coll { using coll_base = base_sparse_allreduce_coll; using coll_strategy = typename coll_base::coll_strategy; @@ -33,140 +34,121 @@ struct cpu_sparse_allreduce_coll using coll_base::recv_vcount; using coll_base::fn_ctxs; - using coll_base::single_send_ibuf; - using coll_base::single_send_vbuf; - using coll_base::single_recv_ibuf; - using coll_base::single_recv_vbuf; - using coll_base::single_recv_icount; - using coll_base::single_recv_vcount; - using coll_base::single_fn_ctx; - cpu_sparse_allreduce_coll(bench_init_attr init_attr, size_t sbuf_size_modifier = 1, size_t rbuf_size_modifier = 1) - : coll_base(init_attr, comm().size()) { + : coll_base(init_attr, transport_data::get_comm_size()) { int result = 0; + int comm_size = transport_data::get_comm_size(); + size_t max_elem_count = base_coll::get_max_elem_count(); - size_t single_buf_max_elem_count = base_coll::get_single_buf_max_elem_count(); - - for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { - result = posix_memalign((void**)&send_ibufs[idx], - ALIGNMENT, - max_elem_count * sizeof(IType) * sbuf_size_modifier); - result |= posix_memalign((void**)&send_vbufs[idx], - ALIGNMENT, - max_elem_count * sizeof(VType) * sbuf_size_modifier); - result |= - posix_memalign((void**)&recv_ibufs[idx], - ALIGNMENT, - max_elem_count * sizeof(IType) * rbuf_size_modifier * comm().size()); - result |= - posix_memalign((void**)&recv_vbufs[idx], - ALIGNMENT, - max_elem_count * sizeof(VType) * rbuf_size_modifier * comm().size()); - if (result != 0) { - std::cerr << __FUNCTION__ << " - posix_memalign error: " << strerror(errno) - << ", on buffer idx: " << idx << std::endl; + + for (size_t rank_idx = 0; rank_idx < base_coll::get_ranks_per_proc(); rank_idx++) { + + for (size_t idx = 0; idx < 
base_coll::get_buf_count(); idx++) { + result = posix_memalign((void**)&(send_ibufs[idx][rank_idx]), + ALIGNMENT, + max_elem_count * sizeof(IType) * sbuf_size_modifier); + result |= posix_memalign((void**)&(send_vbufs[idx][rank_idx]), + ALIGNMENT, + max_elem_count * sizeof(VType) * sbuf_size_modifier); + result |= + posix_memalign((void**)&(recv_ibufs[idx][rank_idx]), + ALIGNMENT, + max_elem_count * sizeof(IType) * rbuf_size_modifier * comm_size); + result |= + posix_memalign((void**)&(recv_vbufs[idx][rank_idx]), + ALIGNMENT, + max_elem_count * sizeof(VType) * rbuf_size_modifier * comm_size); + if (result != 0) { + std::cerr << __FUNCTION__ << " - posix_memalign error: " << strerror(errno) + << ", on buffer idx: " << idx << std::endl; + } } - } - result = posix_memalign((void**)&single_send_ibuf, - ALIGNMENT, - single_buf_max_elem_count * sizeof(IType) * sbuf_size_modifier); - result |= posix_memalign((void**)&single_send_vbuf, - ALIGNMENT, - single_buf_max_elem_count * sizeof(VType) * sbuf_size_modifier); - - result |= posix_memalign( - (void**)&single_recv_ibuf, - ALIGNMENT, - single_buf_max_elem_count * sizeof(IType) * rbuf_size_modifier * comm().size()); - result |= posix_memalign( - (void**)&single_recv_vbuf, - ALIGNMENT, - single_buf_max_elem_count * sizeof(VType) * rbuf_size_modifier * comm().size()); - - for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { - std::memset(send_ibufs[idx], 0, max_elem_count * sizeof(IType)); - std::memset(send_vbufs[idx], 0, max_elem_count * sizeof(VType) * sbuf_size_modifier); - - std::memset(recv_ibufs[idx], - 0, - max_elem_count * sizeof(IType) * rbuf_size_modifier * comm().size()); - std::memset(recv_vbufs[idx], - 0, - max_elem_count * sizeof(VType) * rbuf_size_modifier * comm().size()); - } + for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { + std::memset(send_ibufs[idx][rank_idx], 0, max_elem_count * sizeof(IType)); + std::memset(send_vbufs[idx][rank_idx], 0, max_elem_count * sizeof(VType) * 
sbuf_size_modifier); - std::memset( - single_send_ibuf, 0, single_buf_max_elem_count * sizeof(IType) * sbuf_size_modifier); - std::memset( - single_send_vbuf, 0, single_buf_max_elem_count * sizeof(VType) * sbuf_size_modifier); - - std::memset(single_recv_ibuf, - 0, - single_buf_max_elem_count * sizeof(IType) * rbuf_size_modifier * comm().size()); - std::memset(single_recv_vbuf, - 0, - single_buf_max_elem_count * sizeof(VType) * rbuf_size_modifier * comm().size()); - - for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { - fn_ctxs[idx].recv_ibuf = (void**)(&(recv_ibufs[idx])); - fn_ctxs[idx].recv_vbuf = (void**)(&(recv_vbufs[idx])); - fn_ctxs[idx].recv_ibuf_count = max_elem_count * rbuf_size_modifier * comm().size(); - fn_ctxs[idx].recv_vbuf_count = max_elem_count * rbuf_size_modifier * comm().size(); + std::memset(recv_ibufs[idx][rank_idx], + 0, + max_elem_count * sizeof(IType) * rbuf_size_modifier * comm_size); + std::memset(recv_vbufs[idx][rank_idx], + 0, + max_elem_count * sizeof(VType) * rbuf_size_modifier * comm_size); + } + + for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { + fn_ctxs[idx][rank_idx].recv_ibuf = (void**)(&(recv_ibufs[idx][rank_idx])); + fn_ctxs[idx][rank_idx].recv_vbuf = (void**)(&(recv_vbufs[idx][rank_idx])); + fn_ctxs[idx][rank_idx].recv_ibuf_count = max_elem_count * rbuf_size_modifier * comm_size; + fn_ctxs[idx][rank_idx].recv_vbuf_count = max_elem_count * rbuf_size_modifier * comm_size; + } } - single_fn_ctx.recv_ibuf = (void**)(&single_recv_ibuf); - single_fn_ctx.recv_vbuf = (void**)(&single_recv_vbuf); - single_fn_ctx.recv_ibuf_count = - single_buf_max_elem_count * rbuf_size_modifier * comm().size(); - single_fn_ctx.recv_vbuf_count = - single_buf_max_elem_count * rbuf_size_modifier * comm().size(); } ~cpu_sparse_allreduce_coll() { - for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { - free(send_ibufs[idx]); - free(send_vbufs[idx]); - free(recv_ibufs[idx]); - free(recv_vbufs[idx]); - } - 
free(single_send_ibuf); - free(single_send_vbuf); - free(single_recv_ibuf); - free(single_recv_vbuf); + for (size_t rank_idx = 0; rank_idx < base_coll::get_ranks_per_proc(); rank_idx++) { + + for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { + free(send_ibufs[idx][rank_idx]); + free(send_vbufs[idx][rank_idx]); + free(recv_ibufs[idx][rank_idx]); + free(recv_vbufs[idx][rank_idx]); + } + } } virtual void prepare(size_t elem_count) override { - this->init_distributor({ 0, elem_count }); - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - sparse_detail::fill_sparse_data(this->get_expected_recv_counts(elem_count), - *this->indices_distributor_impl, - elem_count, - send_ibufs[b_idx], - reinterpret_cast(send_vbufs[b_idx]), - reinterpret_cast(recv_vbufs[b_idx]), - fn_ctxs[b_idx].recv_vbuf_count, - recv_icount[b_idx], - recv_vcount[b_idx], - comm().rank()); + + auto& transport = transport_data::instance(); + auto& comms = transport.get_comms(); + size_t ranks_per_proc = base_coll::get_ranks_per_proc(); + + for (size_t rank_idx = 0; rank_idx < ranks_per_proc; rank_idx++) { + + auto& comm = comms[rank_idx]; + + this->init_distributor({ 0, elem_count }); + for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { + sparse_detail::fill_sparse_data(this->get_expected_recv_counts(elem_count), + *this->indices_distributor_impl, + elem_count, + send_ibufs[b_idx][rank_idx], + reinterpret_cast(send_vbufs[b_idx][rank_idx]), + reinterpret_cast(recv_vbufs[b_idx][rank_idx]), + fn_ctxs[b_idx][rank_idx].recv_vbuf_count, + recv_icount[b_idx], + recv_vcount[b_idx], + comm.rank()); + } } } virtual void finalize(size_t elem_count) override { - for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - sparse_detail::check_sparse_result(this->get_expected_recv_counts(elem_count), - elem_count, - send_ibufs[b_idx], - static_cast(send_vbufs[b_idx]), - recv_ibufs[b_idx], - static_cast(recv_vbufs[b_idx]), - recv_icount[b_idx], - 
recv_vcount[b_idx], - comm().size(), - comm().rank()); + + auto& transport = transport_data::instance(); + auto& comms = transport.get_comms(); + size_t ranks_per_proc = base_coll::get_ranks_per_proc(); + + for (size_t rank_idx = 0; rank_idx < ranks_per_proc; rank_idx++) { + + auto& comm = comms[rank_idx]; + + for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { + sparse_detail::check_sparse_result(this->get_expected_recv_counts(elem_count), + elem_count, + send_ibufs[b_idx][rank_idx], + static_cast(send_vbufs[b_idx][rank_idx]), + recv_ibufs[b_idx][rank_idx], + static_cast(recv_vbufs[b_idx][rank_idx]), + recv_icount[b_idx], + recv_vcount[b_idx], + comm.size(), + comm.rank()); + } } } @@ -174,43 +156,30 @@ struct cpu_sparse_allreduce_coll size_t buf_idx, const bench_exec_attr& attr, req_list_t& reqs) override { - coll_strategy::start_internal(comm(), - send_ibufs[buf_idx], - count, - send_vbufs[buf_idx], - count, - recv_ibufs[buf_idx], - recv_icount[buf_idx], - recv_vbufs[buf_idx], - recv_vcount[buf_idx], - attr, - reqs, - fn_ctxs[buf_idx], - coll_strategy::get_op_attr(attr)); - } - virtual void start_single(size_t count, - const bench_exec_attr& attr, - req_list_t& reqs) override { - coll_strategy::start_internal(comm(), - single_send_ibuf, - count, - single_send_vbuf, - count, - static_cast(single_recv_ibuf), - single_recv_icount, - reinterpret_cast(single_recv_vbuf), - single_recv_vcount, - attr, - reqs, - single_fn_ctx, - coll_strategy::get_op_attr(attr)); - } - - /* global communicator for cpu collectives */ - static ccl::communicator& comm() { - if (!host_data::comm_ptr) { + auto& transport = transport_data::instance(); + auto& comms = transport.get_comms(); + size_t ranks_per_proc = base_coll::get_ranks_per_proc(); + + for (size_t rank_idx = 0; rank_idx < ranks_per_proc; rank_idx++) { + + auto& comm = comms[rank_idx]; + + coll_strategy::start_internal(comm, + send_ibufs[buf_idx][rank_idx], + count, + send_vbufs[buf_idx][rank_idx], + count, + 
recv_ibufs[buf_idx][rank_idx], + recv_icount[buf_idx], + recv_vbufs[buf_idx][rank_idx], + recv_vcount[buf_idx], + attr, + reqs, + fn_ctxs[buf_idx][rank_idx], + coll_strategy::get_op_attr(attr)); } - return *host_data::comm_ptr; } }; + +#endif diff --git a/examples/benchmark/src/sparse_allreduce/sparse_allreduce_base.hpp b/examples/benchmark/src/sparse_allreduce/sparse_allreduce_base.hpp index 25a0a90df..8f809f68b 100644 --- a/examples/benchmark/src/sparse_allreduce/sparse_allreduce_base.hpp +++ b/examples/benchmark/src/sparse_allreduce/sparse_allreduce_base.hpp @@ -27,24 +27,16 @@ struct base_sparse_allreduce_coll using coll_base = base_coll; using coll_strategy = sparse_allreduce_strategy_impl; - std::vector send_ibufs; - std::vector send_vbufs; + std::vector> send_ibufs; + std::vector> send_vbufs; /* buffers from these arrays will be reallocated inside completion callback */ - std::vector recv_ibufs; - std::vector recv_vbufs; + std::vector> recv_ibufs; + std::vector> recv_vbufs; size_t* recv_icount = nullptr; size_t* recv_vcount = nullptr; - std::vector fn_ctxs; - - ITypeNonMod* single_send_ibuf = nullptr; - VTypeNonMod* single_send_vbuf = nullptr; - ITypeNonMod* single_recv_ibuf = nullptr; - VTypeNonMod* single_recv_vbuf = nullptr; - size_t single_recv_icount{}; - size_t single_recv_vcount{}; - sparse_allreduce_fn_ctx_t single_fn_ctx; + std::vector> fn_ctxs; base_sparse_allreduce_coll(bench_init_attr init_attr, size_t size) : base_coll(init_attr), @@ -65,6 +57,14 @@ struct base_sparse_allreduce_coll send_vbufs.resize(init_attr.buf_count); recv_ibufs.resize(init_attr.buf_count); recv_vbufs.resize(init_attr.buf_count); + + for (size_t idx = 0; idx < init_attr.buf_count; idx++) { + fn_ctxs[idx].resize(init_attr.ranks_per_proc); + send_ibufs[idx].resize(init_attr.ranks_per_proc); + send_vbufs[idx].resize(init_attr.ranks_per_proc); + recv_ibufs[idx].resize(init_attr.ranks_per_proc); + recv_vbufs[idx].resize(init_attr.ranks_per_proc); + } } virtual 
~base_sparse_allreduce_coll() { @@ -78,6 +78,20 @@ struct base_sparse_allreduce_coll } ccl::datatype get_dtype() const override final { - return ccl::native_type_info::type>::ccl_datatype_value; + return ccl::native_type_info::type>::dtype; + } + + virtual void prepare_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + ASSERT(0, "unexpected"); + } + + virtual void finalize_internal(size_t elem_count, + ccl::communicator& comm, + ccl::stream& stream, + size_t rank_idx) override { + ASSERT(0, "unexpected"); } }; diff --git a/examples/benchmark/src/sparse_allreduce/sparse_allreduce_strategy.hpp b/examples/benchmark/src/sparse_allreduce/sparse_allreduce_strategy.hpp index dc6292475..c8d4a3c43 100644 --- a/examples/benchmark/src/sparse_allreduce/sparse_allreduce_strategy.hpp +++ b/examples/benchmark/src/sparse_allreduce/sparse_allreduce_strategy.hpp @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ +#if 0 + #pragma once template @@ -23,7 +25,7 @@ struct type_printer { }; template <> -struct type_printer { +struct type_printer { static constexpr const char* sparse_class_name() { return "sparse_allreduce_bf16"; } @@ -130,7 +132,7 @@ struct sparse_allreduce_strategy_impl { using IndicesDistributor = IndicesDistributorType>; size_t v2i_ratio; - size_t comm_size; + int comm_size; const size_t minimal_indices_count = 1; void init_distributor(const std::pair& elem_range) { @@ -138,7 +140,7 @@ struct sparse_allreduce_strategy_impl { indices_distributor_impl.reset(new IndicesDistributor(elem_range.first, indices_count)); } - sparse_allreduce_strategy_impl(size_t v2i_ratio, size_t comm_size) + sparse_allreduce_strategy_impl(size_t v2i_ratio, int comm_size) : v2i_ratio(v2i_ratio), comm_size(comm_size) {} @@ -153,8 +155,8 @@ struct sparse_allreduce_strategy_impl { return std::tuple(indices_count, indices_count * vdim_count); } - template - void start_internal(comm_t& comm, + template + void start_internal(ccl::communicator& comm, const IType send_ibuf, size_t send_icount, const VType send_vbuf, @@ -197,3 +199,5 @@ struct sparse_allreduce_strategy_impl { std::unique_ptr indices_distributor_impl; }; + +#endif diff --git a/examples/benchmark/src/sparse_allreduce/sparse_detail.hpp b/examples/benchmark/src/sparse_allreduce/sparse_detail.hpp index 4f07238ac..9a33c16e8 100644 --- a/examples/benchmark/src/sparse_allreduce/sparse_detail.hpp +++ b/examples/benchmark/src/sparse_allreduce/sparse_detail.hpp @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ +#if 0 + #pragma once #include @@ -89,14 +91,14 @@ void fill_sparse_data(const std::tuple& expected_recv_counts, std::fill(recv_vbuf, recv_vbuf + recv_vbuf_count, ValueType{ 0 }); } -// override for ccl::bf16 +// override for ccl::bfloat16 template void fill_sparse_data(const std::tuple& expected_recv_counts, IndicesDistributorType& generator, size_t elem_count, IndexType* send_ibuf, - ccl::bf16* send_vbuf, - ccl::bf16* recv_vbuf, + ccl::bfloat16* send_vbuf, + ccl::bfloat16* recv_vbuf, size_t recv_vbuf_count, size_t& recv_icount, size_t& recv_vcount, @@ -117,7 +119,7 @@ void fill_sparse_data(const std::tuple& expected_recv_counts, } } - std::fill(recv_vbuf, recv_vbuf + recv_vbuf_count, ccl::bf16{ 0 }); + std::fill(recv_vbuf, recv_vbuf + recv_vbuf_count, ccl::bfloat16{ 0 }); // convert send_vbuf from float to send_vbuf in bf16 convert_fp32_to_bf16_arrays(send_vbuf_from.data(), send_vbuf, elem_count); @@ -132,8 +134,8 @@ void check_sparse_result(const std::tuple& expected_recv_counts, const ValueType* recv_vbuf, size_t recv_icount, size_t recv_vcount, - size_t comm_size, - size_t comm_rank) { + int comm_size, + int comm_rank) { size_t indices_count, vdim_count; std::tie(indices_count, vdim_count) = expected_recv_counts; vdim_count = vdim_count / indices_count; @@ -150,7 +152,7 @@ void check_sparse_result(const std::tuple& expected_recv_counts, base_send_data.begin(), std::bind(std::minus(), std::placeholders::_1, comm_rank)); - for (size_t rank_index = 0; rank_index < comm_size; rank_index++) { + for (int rank_index = 0; rank_index < comm_size; rank_index++) { std::copy(send_ibuf, send_ibuf + indices_count, std::back_inserter(aggregated_indices)); std::transform(base_send_data.begin(), @@ -248,18 +250,18 @@ void check_sparse_result(const std::tuple& expected_recv_counts, } } -// override for ccl::bf16 +// override for ccl::bfloat16 template void check_sparse_result(const std::tuple& expected_recv_counts, size_t elem_count, const IndexType* send_ibuf, - const 
ccl::bf16* send_vbuf, + const ccl::bfloat16* send_vbuf, const IndexType* recv_ibuf, - const ccl::bf16* recv_vbuf, + const ccl::bfloat16* recv_vbuf, size_t recv_icount, size_t recv_vcount, - size_t comm_size, - size_t comm_rank) { + int comm_size, + int comm_rank) { size_t indices_count, vdim_count; std::tie(indices_count, vdim_count) = expected_recv_counts; vdim_count = vdim_count / indices_count; @@ -270,7 +272,7 @@ void check_sparse_result(const std::tuple& expected_recv_counts, std::vector aggregated_values; aggregated_values.reserve(indices_count * vdim_count * comm_size); - for (size_t rank_index = 0; rank_index < comm_size; rank_index++) { + for (int rank_index = 0; rank_index < comm_size; rank_index++) { std::copy(send_ibuf, send_ibuf + indices_count, std::back_inserter(aggregated_indices)); for (size_t i_idx = 0; i_idx < indices_count; i_idx++) { @@ -312,7 +314,7 @@ void check_sparse_result(const std::tuple& expected_recv_counts, // check received values std::vector recv_vbuf_float(recv_vcount, float{ 0 }); - convert_bf16_to_fp32_arrays(reinterpret_cast(const_cast(recv_vbuf)), + convert_bf16_to_fp32_arrays(reinterpret_cast(const_cast(recv_vbuf)), recv_vbuf_float.data(), recv_vcount); @@ -387,3 +389,5 @@ void check_sparse_result(const std::tuple& expected_recv_counts, } } } /* namespace sparse_detail */ + +#endif diff --git a/examples/benchmark/src/sparse_allreduce/sycl_sparse_allreduce_coll.hpp b/examples/benchmark/src/sparse_allreduce/sycl_sparse_allreduce_coll.hpp index efd2c5a1c..ea7507a89 100644 --- a/examples/benchmark/src/sparse_allreduce/sycl_sparse_allreduce_coll.hpp +++ b/examples/benchmark/src/sparse_allreduce/sycl_sparse_allreduce_coll.hpp @@ -13,25 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ +#if 0 + #pragma once #ifdef CCL_ENABLE_SYCL #include "sycl_coll.hpp" -template -struct sparse_allreduce_kernel_name_bufs {}; -template -struct sparse_allreduce_kernel_name_single_bufs {}; - template class IndicesDistributorType = sparse_detail::incremental_indices_distributor> struct sycl_sparse_allreduce_coll : base_sparse_allreduce_coll, cl::sycl::buffer, - IndicesDistributorType>, - device_data { + IndicesDistributorType> { using sycl_indices_t = cl::sycl::buffer; using sycl_values_t = cl::sycl::buffer; using coll_base = @@ -46,94 +42,51 @@ struct sycl_sparse_allreduce_coll : base_sparse_allreduce_coll(send_ibufs[idx])); - auto send_vbuf = (static_cast(send_vbufs[idx])); - - auto recv_ibuf = (static_cast(recv_ibufs[idx])); - auto recv_vbuf = (static_cast(recv_vbufs[idx])); - - auto send_ibuf_acc = send_ibuf->template get_access(cgh); - auto send_vbuf_acc = send_vbuf->template get_access(cgh); - auto recv_ibuf_acc = recv_ibuf->template get_access(cgh); - auto recv_vbuf_acc = recv_vbuf->template get_access(cgh); - - cgh.parallel_for> - (range<1>{max_elem_count*comm().size()}, [=](item<1> e_idx) - { - if (e_idx.get_linear_id() < max_elem_count) { - send_ibuf_acc[e_idx] = 0; - send_vbuf_acc[e_idx] = 0; - } - recv_ibuf_acc[e_idx] = 0; - recv_vbuf_acc[e_idx] = 0; - }); - }); - } - - single_send_ibuf = new sycl_indices_t(single_buf_max_elem_count * sbuf_size_modifier); - single_send_vbuf = new sycl_values_t(single_buf_max_elem_count * sbuf_size_modifier); - - single_recv_ibuf = - new sycl_indices_t(single_buf_max_elem_count * rbuf_size_modifier * comm().size()); - single_recv_vbuf = - new sycl_values_t(single_buf_max_elem_count * rbuf_size_modifier * comm().size()); - - device_data::sycl_queue.submit([&](handler& cgh) { - auto send_ibuf = (static_cast(single_send_ibuf)); - auto send_vbuf = (static_cast(single_send_vbuf)); - - auto recv_ibuf = (static_cast(single_recv_ibuf)); - auto recv_vbuf = (static_cast(single_recv_vbuf)); - - auto send_ibuf_acc = 
send_ibuf->template get_access(cgh); - auto send_vbuf_acc = send_vbuf->template get_access(cgh); - - auto recv_ibuf_acc = recv_ibuf->template get_access(cgh); - auto recv_vbuf_acc = recv_vbuf->template get_access(cgh); - - cgh.parallel_for> - (range<1>{ single_buf_max_elem_count * comm().size() }, [=](item<1> e_idx) - { - if (e_idx.get_linear_id() < single_buf_max_elem_count) { - send_ibuf_acc[e_idx] = 0; - send_vbuf_acc[e_idx] = 0; - } - recv_ibuf_acc[e_idx] = 0; - recv_vbuf_acc[e_idx] = 0; - }); - }); - - for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { - fn_ctxs[idx].recv_ibuf = (void**)(&(recv_ibufs[idx])); - fn_ctxs[idx].recv_vbuf = (void**)(&(recv_vbufs[idx])); - } - single_fn_ctx.recv_ibuf = (void**)(&single_recv_ibuf); - single_fn_ctx.recv_vbuf = (void**)(&single_recv_vbuf); + : coll_base(init_attr, transport_data::get_comm_size()) { + // size_t max_elem_count = base_coll::get_max_elem_count(); + + // int comm_size = transport_data::get_comm_size(); + + // for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { + // send_ibufs[idx] = new sycl_indices_t(max_elem_count * sbuf_size_modifier); + // send_vbufs[idx] = new sycl_values_t(max_elem_count * sbuf_size_modifier); + + // recv_ibufs[idx] = + // new sycl_indices_t(max_elem_count * rbuf_size_modifier * comm_size); + // recv_vbufs[idx] = + // new sycl_values_t(max_elem_count * rbuf_size_modifier * comm_size); + + // stream.get_native().submit([&](handler& h) { + // auto send_ibuf = (static_cast(send_ibufs[idx])); + // auto send_vbuf = (static_cast(send_vbufs[idx])); + + // auto recv_ibuf = (static_cast(recv_ibufs[idx])); + // auto recv_vbuf = (static_cast(recv_vbufs[idx])); + + // auto send_ibuf_acc = send_ibuf->template get_access(h); + // auto send_vbuf_acc = send_vbuf->template get_access(h); + // auto recv_ibuf_acc = recv_ibuf->template get_access(h); + // auto recv_vbuf_acc = recv_vbuf->template get_access(h); + + // h.parallel_for(range<1>{max_elem_count*comm_size}, [=](item<1> 
e_idx) + // { + // if (e_idx.get_linear_id() < max_elem_count) { + // send_ibuf_acc[e_idx] = 0; + // send_vbuf_acc[e_idx] = 0; + // } + // recv_ibuf_acc[e_idx] = 0; + // recv_vbuf_acc[e_idx] = 0; + // }); + // }).wait(); + // } + + // for (size_t idx = 0; idx < base_coll::get_buf_count(); idx++) { + // fn_ctxs[idx].recv_ibuf = (void**)(&(recv_ibufs[idx])); + // fn_ctxs[idx].recv_vbuf = (void**)(&(recv_vbufs[idx])); + // } } virtual void prepare(size_t elem_count) override { @@ -143,6 +96,22 @@ struct sycl_sparse_allreduce_coll : base_sparse_allreduce_coll*>(single_send_ibuf), - count, - *reinterpret_cast*>(single_send_vbuf), - count, - *static_cast*>(single_recv_ibuf), - single_recv_icount, - *reinterpret_cast*>(single_recv_vbuf), - single_recv_vcount, - attr, - reqs, - single_fn_ctx, - stream(), - coll_strategy::get_op_attr(attr));*/ - } - - /* global communicator for cpu collectives */ - static ccl::communicator& comm() { - if (!device_data::comm_ptr) { - } - return *device_data::comm_ptr; - } - - static ccl::stream& stream() { - if (!device_data::stream_ptr) { - } - return *device_data::stream_ptr; - } }; #endif /* CCL_ENABLE_SYCL */ + +#endif diff --git a/examples/benchmark/src/transport_impl.hpp b/examples/benchmark/src/transport_impl.hpp index e63d525b1..3c6b81b9c 100644 --- a/examples/benchmark/src/transport_impl.hpp +++ b/examples/benchmark/src/transport_impl.hpp @@ -13,61 +13,138 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#pragma once - #include -#include "base.hpp" +#ifdef CCL_ENABLE_SYCL +#include +#include "sycl_coll.hpp" +#endif /* CCL_ENABLE_SYCL */ + #include "transport.hpp" -transport_settings::transport_settings() { +transport_data::transport_data() { init_by_mpi(); + + service_comms.push_back(ccl::create_communicator(size, rank, kvs)); } -transport_settings::~transport_settings() { +transport_data::~transport_data() { deinit_by_mpi(); } -transport_settings &transport_settings::instance() { - static transport_settings inst; +transport_data& transport_data::instance() { + static transport_data inst; return inst; } -int transport_settings::get_rank() const noexcept { +size_t transport_data::get_comm_size() { + return transport_data::instance().get_comms()[0].size(); +} + +int transport_data::get_rank() const noexcept { return rank; } -int transport_settings::get_size() const noexcept { +int transport_data::get_size() const noexcept { return size; } -ccl::shared_ptr_class transport_settings::get_kvs() { +ccl::shared_ptr_class transport_data::get_kvs() { return kvs; } -void transport_settings::init_by_mpi() { - +void transport_data::init_by_mpi() { ccl::init(); MPI_Init(NULL, NULL); MPI_Comm_size(MPI_COMM_WORLD, &size); MPI_Comm_rank(MPI_COMM_WORLD, &rank); - /* create CCL internal KVS */ ccl::shared_ptr_class kvs_candidate; ccl::kvs::address_type main_addr; if (rank == 0) { kvs_candidate = ccl::create_main_kvs(); main_addr = kvs_candidate->get_address(); - MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); } else { - MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); kvs_candidate = ccl::create_kvs(main_addr); } kvs = kvs_candidate; } -void transport_settings::deinit_by_mpi() { +void transport_data::deinit_by_mpi() { MPI_Finalize(); } + 
+ccl::communicator& transport_data::get_service_comm() { + return service_comms[0]; +} + +std::vector& transport_data::get_streams() { + return streams; +} + +std::vector& transport_data::get_bench_streams() { + return bench_streams; +} + +void transport_data::init_comms(user_options_t& options) { + int ranks_per_proc = options.ranks_per_proc; + + std::vector local_ranks; + for (int idx = 0; idx < ranks_per_proc; idx++) { + local_ranks.push_back(rank * ranks_per_proc + idx); + } + + ccl::context context = ccl::create_context(); + std::vector devices; + std::map r2d_map; + + if (options.backend == BACKEND_HOST) { + for (int idx = 0; idx < ranks_per_proc; idx++) { + streams.push_back(ccl::create_stream()); + bench_streams.push_back(ccl::create_stream()); + devices.push_back(ccl::create_device()); + } + } +#ifdef CCL_ENABLE_SYCL + else if (options.backend == BACKEND_SYCL) { + auto sycl_queues = create_sycl_queues(sycl_dev_names[options.sycl_dev_type], local_ranks); + ASSERT(!sycl_queues.empty(), "queues should contain at least one queue"); + ASSERT(ranks_per_proc == sycl_queues.size(), "ranks and queues sizes should match"); + + auto sycl_context = sycl_queues[0].get_context(); + context = ccl::create_context(sycl_context); + + for (int idx = 0; idx < ranks_per_proc; idx++) { + streams.push_back(ccl::create_stream(sycl_queues[idx])); + auto q = sycl::queue(sycl_queues[idx].get_context(), sycl_queues[idx].get_device()); + bench_streams.push_back(ccl::create_stream(q)); + devices.push_back(ccl::create_device(sycl_queues[idx].get_device())); + // TODO: multidevice unsupported yet + // ASSERT(sycl_context == sycl_queues[idx].get_context(), + // "all sycl queues should be from the same sycl context"); + } + } +#endif /* CCL_ENABLE_SYCL */ + else { + ASSERT(0, "unknown backend %d", (int)options.backend); + } + + for (int idx = 0; idx < ranks_per_proc; idx++) { + r2d_map.emplace(local_ranks[idx], devices[idx]); + } + + comms = ccl::create_communicators(size * ranks_per_proc, 
r2d_map, context, kvs); + + ASSERT((int)comms.size() == ranks_per_proc, + "unexpected comms size %zu, expected %d", + comms.size(), + ranks_per_proc); +} + +std::vector& transport_data::get_comms() { + return comms; +} diff --git a/examples/common/CMakeLists.txt b/examples/common/CMakeLists.txt index 9ff37b6e3..a5df2d71c 100644 --- a/examples/common/CMakeLists.txt +++ b/examples/common/CMakeLists.txt @@ -25,5 +25,6 @@ foreach(src ${sources}) target_link_libraries(${executable} PUBLIC rt) target_link_libraries(${executable} PUBLIC m) target_link_libraries(${executable} PUBLIC dl) - install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/common) + target_link_libraries(${executable} PUBLIC mpi) + install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/common OPTIONAL) endforeach() diff --git a/examples/common/version.cpp b/examples/common/version.cpp index 2cd8e4a19..bdbf47078 100644 --- a/examples/common/version.cpp +++ b/examples/common/version.cpp @@ -30,13 +30,14 @@ int main() { CCL_PRODUCT_FULL); printf("\nRuntime CCL library version:\nmajor: %d\nminor: %d\nupdate: %d\n" - "Product: %s\nBuild date: %s\nFull: %s\n", + "Product: %s\nBuild date: %s\nFull: %s\ncl_backend name: %s\n", version.major, version.minor, version.update, version.product_status, version.build_date, - version.full); + version.full, + version.cl_backend_name.c_str()); printf("\noneCCL specification version: %s\n", ONECCL_SPEC_VERSION); diff --git a/examples/cpu/CMakeLists.txt b/examples/cpu/CMakeLists.txt index 07da2ea50..58099643c 100644 --- a/examples/cpu/CMakeLists.txt +++ b/examples/cpu/CMakeLists.txt @@ -27,6 +27,6 @@ foreach(src ${sources}) target_link_libraries(${executable} PUBLIC stdc++) target_link_libraries(${executable} PRIVATE m) target_link_libraries(${executable} PUBLIC mpi) - install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/cpu) + install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/cpu OPTIONAL) 
endforeach() diff --git a/examples/cpu/allgatherv.cpp b/examples/cpu/allgatherv.cpp index 5ddc50e8a..393710f9f 100644 --- a/examples/cpu/allgatherv.cpp +++ b/examples/cpu/allgatherv.cpp @@ -15,14 +15,12 @@ */ #include "base.hpp" -void run_collective( - const char* cmd_name, - std::vector& send_buf, - std::vector& recv_buf, - std::vector& recv_counts, - const ccl::communicator& comm, - const ccl::allgatherv_attr& attr) { - +void run_collective(const char* cmd_name, + std::vector& send_buf, + std::vector& recv_buf, + std::vector& recv_counts, + const ccl::communicator& comm, + const ccl::allgatherv_attr& attr) { std::chrono::system_clock::duration exec_time{ 0 }; float expected = send_buf.size(); float received; @@ -31,13 +29,8 @@ void run_collective( for (size_t idx = 0; idx < ITERS; ++idx) { auto start = std::chrono::system_clock::now(); - ccl::allgatherv( - send_buf.data(), - send_buf.size(), - recv_buf.data(), - recv_counts, - comm, - attr).wait(); + ccl::allgatherv(send_buf.data(), send_buf.size(), recv_buf.data(), recv_counts, comm, attr) + .wait(); exec_time += std::chrono::system_clock::now() - start; } @@ -58,14 +51,12 @@ void run_collective( << ", us" << std::endl; } -void run_collective_vector( - const char* cmd_name, - std::vector& send_buf, - std::vector& recv_bufs, - std::vector& recv_counts, - const ccl::communicator& comm, - const ccl::allgatherv_attr& attr) { - +void run_collective_vector(const char* cmd_name, + std::vector& send_buf, + std::vector& recv_bufs, + std::vector& recv_counts, + const ccl::communicator& comm, + const ccl::allgatherv_attr& attr) { std::chrono::system_clock::duration exec_time{ 0 }; float expected = send_buf.size(); float received; @@ -74,13 +65,8 @@ void run_collective_vector( for (size_t idx = 0; idx < ITERS; ++idx) { auto start = std::chrono::system_clock::now(); - ccl::allgatherv( - send_buf.data(), - send_buf.size(), - recv_bufs, - recv_counts, - comm, - attr).wait(); + ccl::allgatherv(send_buf.data(), send_buf.size(), 
recv_bufs, recv_counts, comm, attr) + .wait(); exec_time += std::chrono::system_clock::now() - start; } @@ -104,7 +90,6 @@ void run_collective_vector( } int main() { - ccl::init(); int size, rank; @@ -114,50 +99,46 @@ int main() { ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; + auto kvs_attr = ccl::create_kvs_attr(); if (rank == 0) { - kvs = ccl::create_main_kvs(); + kvs = ccl::create_main_kvs(kvs_attr); main_addr = kvs->get_address(); MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); } else { MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); - kvs = ccl::create_kvs(main_addr); + kvs = ccl::create_kvs(main_addr, kvs_attr); } auto dev = ccl::create_device(); auto ctx = ccl::create_context(); - auto comm = ccl::create_communicator(size, rank, dev, ctx, kvs); + auto comm_attr = ccl::create_comm_attr(); + auto comm = ccl::create_communicator(size, rank, dev, ctx, kvs, comm_attr); auto attr = ccl::create_operation_attr(); - MSG_LOOP( - comm, - - std::vector send_buf(msg_count, static_cast(msg_count)); - std::vector recv_buf(comm.size() * msg_count, 0); - std::vector recv_bufs(comm.size(), nullptr); - std::vector recv_counts(comm.size(), msg_count); - - for (size_t idx = 0; idx < comm.size(); idx++) - recv_bufs[idx] = new float[msg_count]; - - attr.set(false); - run_collective( - "warmup_allgatherv", send_buf, recv_buf, recv_counts, comm, attr); - run_collective_vector( - "warmup_allgatherv_vector", send_buf, recv_bufs, recv_counts, comm, attr); - - attr.set(true); - run_collective( - "persistent_allgatherv", send_buf, recv_buf, recv_counts, comm, attr); - run_collective_vector( - "persistent_allgatherv_vector", send_buf, recv_bufs, recv_counts, comm, attr); - - attr.set(false); - run_collective( - "regular_allgatherv", send_buf, recv_buf, recv_counts, comm, attr); - run_collective_vector( - "regular_allgatherv_vector", send_buf, recv_bufs, recv_counts, comm, attr); - ); + MSG_LOOP(comm, + + 
std::vector send_buf(msg_count, static_cast(msg_count)); + std::vector recv_buf(comm.size() * msg_count, 0); + std::vector recv_bufs(comm.size(), nullptr); + std::vector recv_counts(comm.size(), msg_count); + + for (int idx = 0; idx < comm.size(); idx++) recv_bufs[idx] = new float[msg_count]; + + attr.set(false); + run_collective("warmup_allgatherv", send_buf, recv_buf, recv_counts, comm, attr); + run_collective_vector( + "warmup_allgatherv_vector", send_buf, recv_bufs, recv_counts, comm, attr); + + attr.set(true); + run_collective("persistent_allgatherv", send_buf, recv_buf, recv_counts, comm, attr); + run_collective_vector( + "persistent_allgatherv_vector", send_buf, recv_bufs, recv_counts, comm, attr); + + attr.set(false); + run_collective("regular_allgatherv", send_buf, recv_buf, recv_counts, comm, attr); + run_collective_vector( + "regular_allgatherv_vector", send_buf, recv_bufs, recv_counts, comm, attr);); MPI_Finalize(); diff --git a/examples/cpu/allreduce.cpp b/examples/cpu/allreduce.cpp index 14fe29b72..944ae3afb 100644 --- a/examples/cpu/allreduce.cpp +++ b/examples/cpu/allreduce.cpp @@ -27,12 +27,9 @@ void run_collective(const char* cmd_name, for (size_t idx = 0; idx < ITERS; ++idx) { auto start = std::chrono::system_clock::now(); - ccl::allreduce(send_buf.data(), - recv_buf.data(), - recv_buf.size(), - ccl::reduction::sum, - comm, - attr).wait(); + ccl::allreduce( + send_buf.data(), recv_buf.data(), recv_buf.size(), ccl::reduction::sum, comm, attr) + .wait(); exec_time += std::chrono::system_clock::now() - start; } @@ -53,7 +50,6 @@ void run_collective(const char* cmd_name, } int main() { - ccl::init(); int size, rank; diff --git a/examples/cpu/alltoallv.cpp b/examples/cpu/alltoallv.cpp index 639f06130..bca29d482 100644 --- a/examples/cpu/alltoallv.cpp +++ b/examples/cpu/alltoallv.cpp @@ -34,7 +34,7 @@ void run_collective(const char* cmd_name, std::fill(recv_buf.begin(), recv_buf.end(), 0); size_t elem_idx = 0; - for (size_t rank_idx = 0; rank_idx < 
comm.size(); rank_idx++) { + for (int rank_idx = 0; rank_idx < comm.size(); rank_idx++) { for (size_t idx = 0; idx < send_counts[rank_idx]; idx++) { send_buf[elem_idx] = comm.rank(); elem_idx++; @@ -42,19 +42,15 @@ void run_collective(const char* cmd_name, } auto start = std::chrono::system_clock::now(); - ccl::alltoallv(send_buf.data(), - send_counts, - recv_buf.data(), - recv_counts, - comm, - attr).wait(); + ccl::alltoallv(send_buf.data(), send_counts, recv_buf.data(), recv_counts, comm, attr) + .wait(); exec_time += std::chrono::system_clock::now() - start; } ccl::barrier(comm); size_t elem_idx = 0; - for (size_t rank_idx = 0; rank_idx < comm.size(); rank_idx++) { + for (int rank_idx = 0; rank_idx < comm.size(); rank_idx++) { int expected = rank_idx; for (size_t idx = 0; idx < recv_counts[rank_idx]; idx++) { if (recv_buf[elem_idx] != expected) { @@ -77,7 +73,6 @@ void run_collective(const char* cmd_name, } int main() { - ccl::init(); int size, rank; @@ -112,23 +107,20 @@ int main() { std::vector send_counts(comm.size()); std::vector recv_counts(comm.size()); - for (size_t idx = 0; idx < comm.size(); idx++) { + for (int idx = 0; idx < comm.size(); idx++) { int is_even_peer = (idx % 2 == 0) ? 1 : 0; send_counts[idx] = send_count; recv_counts[idx] = (is_even_peer) ? 
EVEN_RANK_SEND_COUNT : ODD_RANK_SEND_COUNT; } - MSG_LOOP( - comm, - attr.set(false); - run_collective( - "warmup alltoallv", send_buf, recv_buf, send_counts, recv_counts, comm, attr); - attr.set(true); - run_collective( - "persistent alltoallv", send_buf, recv_buf, send_counts, recv_counts, comm, attr); - attr.set(false); - run_collective( - "regular alltoallv", send_buf, recv_buf, send_counts, recv_counts, comm, attr);); + MSG_LOOP(comm, attr.set(false); run_collective( + "warmup alltoallv", send_buf, recv_buf, send_counts, recv_counts, comm, attr); + attr.set(true); + run_collective( + "persistent alltoallv", send_buf, recv_buf, send_counts, recv_counts, comm, attr); + attr.set(false); + run_collective( + "regular alltoallv", send_buf, recv_buf, send_counts, recv_counts, comm, attr);); MPI_Finalize(); diff --git a/examples/cpu/broadcast.cpp b/examples/cpu/broadcast.cpp index 22420e597..1bea2985d 100644 --- a/examples/cpu/broadcast.cpp +++ b/examples/cpu/broadcast.cpp @@ -31,11 +31,7 @@ void run_collective(const char* cmd_name, for (size_t idx = 0; idx < ITERS; ++idx) { auto start = std::chrono::system_clock::now(); - ccl::broadcast(buf.data(), - buf.size(), - COLL_ROOT, - comm, - attr).wait(); + ccl::broadcast(buf.data(), buf.size(), COLL_ROOT, comm, attr).wait(); exec_time += std::chrono::system_clock::now() - start; } @@ -61,7 +57,6 @@ void run_collective(const char* cmd_name, } int main() { - ccl::init(); int size, rank; diff --git a/examples/cpu/communicator.cpp b/examples/cpu/communicator.cpp index c12e1c31b..1fe2603db 100644 --- a/examples/cpu/communicator.cpp +++ b/examples/cpu/communicator.cpp @@ -61,86 +61,13 @@ void check_max_comm_number(const ccl::communicator& comm, } while (true); PRINT_BY_ROOT(comm, "created %zu communicators", user_comms); - // PRINT_BY_ROOT(comm, "try to create one more communicator, it should fail"); - - // try - // { - // auto comm = ccl::environment::instance().create_communicator(); - // printf("FAILED\n"); - // throw 
std::runtime_error("extra communicator has been created"); - // } - // catch(...) - // {} - - // PRINT_BY_ROOT(comm, "free one comm, try to create again"); - // size_t comm_idx = user_comms / 2; - - // try - // { - // communicators[comm_idx].reset(); - // } - // catch (...) - // { - // printf("FAILED\n"); - // throw std::runtime_error("can't free communicator"); - // } - - // try - // { - // communicators[comm_idx] = ccl::environment::instance().create_communicator(); - // } - // catch (...) - // { - // printf("FAILED\n"); - // throw std::runtime_error("can't create communicator after free"); - // } } -// void check_comm_create_identical_color() -// { -// size_t comm_size{}; -// size_t comm_rank{}; - -// PRINT_BY_ROOT(global_comm, -// "create comm as a copy of the global one by settings identical colors"); - -// ccl::comm_attr_t comm_attr = ccl::environment::instance().create_host_comm_attr(); -// comm_attr->set_value(123); -// auto comm = ccl::environment::instance().create_communicator(comm_attr); - -// comm_size = comm->size(); -// comm_rank = comm->rank(); - -// if (comm_size != global_comm->size()) -// { -// printf("FAILED\n"); -// throw std::runtime_error("mismatch in size, expected " + -// to_string(global_comm->size()) + -// " received " + to_string(comm_size)); -// } - -// if (comm_rank != global_comm->rank()) -// { -// printf("FAILED\n"); -// throw std::runtime_error("mismatch in rank, expected " + -// to_string(global_comm->rank()) + -// " received " + to_string(comm_rank)); -// } - -// PRINT_BY_ROOT(global_comm, -// "global comm: rank = %zu, size = %zu; " -// "new comm: rank = %zu, size = %zu", -// global_comm->rank(), global_comm->size(), -// comm_rank, comm_size); - -// check_allreduce_on_comm(comm); -// } - bool isPowerOfTwo(unsigned int x) { return x && !(x & (x - 1)); } -void check_comm_split_by_color(ccl::communicator& comm, int mpi_size, int mpi_rank) { +void check_comm_split_by_color(ccl::communicator& comm) { if (!isPowerOfTwo(comm.size())) { 
PRINT_BY_ROOT( comm, @@ -148,18 +75,18 @@ void check_comm_split_by_color(ccl::communicator& comm, int mpi_size, int mpi_ra return; } - for (size_t split_by = 2; split_by <= comm.size(); split_by *= 2) { + for (int split_by = 2; split_by <= comm.size(); split_by *= 2) { int color = comm.rank() % split_by; - auto attr = - ccl::create_comm_split_attr(ccl::attr_val(color)); + auto attr = ccl::preview::create_comm_split_attr( + ccl::attr_val(color)); auto new_comm = comm.split(attr); - size_t comm_size = comm.size(); - size_t new_comm_size = new_comm.size(); - size_t comm_rank = comm.rank(); - size_t new_comm_rank = new_comm.rank(); + int comm_size = comm.size(); + int new_comm_size = new_comm.size(); + int comm_rank = comm.rank(); + int new_comm_rank = new_comm.rank(); - size_t expected_new_comm_size = comm_size / split_by; + int expected_new_comm_size = comm_size / split_by; if (new_comm_size != expected_new_comm_size) { printf("FAILED (split)\n"); @@ -170,18 +97,97 @@ void check_comm_split_by_color(ccl::communicator& comm, int mpi_size, int mpi_ra } PRINT_BY_ROOT(comm, - "base comm: rank = %zu, size = %zu; " - "new comm: rank = %zu, size = %zu", + "base comm: rank = %d, size = %d; " + "new comm: rank = %d, size = %d", comm_rank, comm_size, new_comm_rank, new_comm_size); + PRINT_BY_ROOT(comm, " - allreduce test on a new communicator"); check_allreduce_on_comm(new_comm); } } +void check_comm_split_identical(ccl::communicator& comm) { + if (!isPowerOfTwo(comm.size())) { + PRINT_BY_ROOT( + comm, + "split comm by color: number of processes should be a power of 2 for test purpose"); + return; + } + + for (int split_by = 2; split_by <= comm.size(); split_by *= 2) { + int color = comm.rank() % split_by; + auto attr = ccl::preview::create_comm_split_attr( + ccl::attr_val(color)); + auto new_comm1 = comm.split(attr); + auto new_comm2 = comm.split(attr); + + if (new_comm1.size() != new_comm2.size()) { + printf("FAILED (split)\n"); + + throw std::runtime_error("the sizes of new 
communicators are not equal. Comm #1 size " + + std::to_string(new_comm1.size()) + " Comm #2 size " + + std::to_string(new_comm2.size())); + } + + if (new_comm1.rank() != new_comm2.rank()) { + printf("FAILED (split)\n"); + + throw std::runtime_error("the sizes of new communicators are not equal. Comm #1 rank " + + std::to_string(new_comm1.rank()) + " Comm #2 rank " + + std::to_string(new_comm2.rank())); + } + + PRINT_BY_ROOT(comm, + "comm #1: rank = %d, size = %d; " + "comm #2: rank = %d, size = %d", + new_comm1.rank(), + new_comm1.size(), + new_comm2.rank(), + new_comm2.size()); + } +} + +void check_comm_split_identical_color(ccl::communicator& comm) { + auto attr = + ccl::preview::create_comm_split_attr(ccl::attr_val(123)); + auto new_comm = comm.split(attr); + + if (new_comm.size() != comm.size()) { + printf("FAILED (split)\n"); + + throw std::runtime_error( + "the sizes of new communicator and base communicator are not equal. New comm size " + + std::to_string(new_comm.size()) + " Base comm size " + std::to_string(comm.size())); + } + + if (new_comm.rank() != comm.rank()) { + printf("FAILED (split)\n"); + + throw std::runtime_error( + "the sizes of new communicator and base communicator are not equal. 
New comm rank " + + std::to_string(new_comm.rank()) + " Base comm rank " + std::to_string(comm.rank())); + } + + PRINT_BY_ROOT(comm, + "base comm: rank = %d, size = %d; " + "new comm: rank = %d, size = %d", + comm.rank(), + new_comm.size(), + comm.rank(), + new_comm.size()); + + PRINT_BY_ROOT(comm, " - allreduce test on a new communicator"); + check_allreduce_on_comm(new_comm); +} + int main() { + /** + * The example only works with CCL_ATL_TRANSPORT=ofi + */ + setenv("CCL_ATL_TRANSPORT", "ofi", 0); ccl::init(); @@ -213,10 +219,16 @@ int main() { // PRINT_BY_ROOT(comm, "PASSED"); PRINT_BY_ROOT(comm, "\n- Communicator split test"); - check_comm_split_by_color(comm, mpi_size, mpi_rank); + check_comm_split_by_color(comm); + PRINT_BY_ROOT(comm, "PASSED"); + + PRINT_BY_ROOT(comm, "\n- Communicator identical split test"); + check_comm_split_identical(comm); PRINT_BY_ROOT(comm, "PASSED"); - // check_comm_create_identical_color(); + PRINT_BY_ROOT(comm, "\n- Communicator identical color split test"); + check_comm_split_identical_color(comm); + PRINT_BY_ROOT(comm, "PASSED"); MPI_Finalize(); diff --git a/examples/cpu/cpu_allgatherv_test.cpp b/examples/cpu/cpu_allgatherv_test.cpp index 66afedcb9..5083655db 100644 --- a/examples/cpu/cpu_allgatherv_test.cpp +++ b/examples/cpu/cpu_allgatherv_test.cpp @@ -22,7 +22,6 @@ using namespace std; int main() { - const size_t count = 128; size_t i = 0; @@ -64,11 +63,7 @@ int main() { } /* invoke allgatherv */ - ccl::allgatherv(send_buf.data(), - count, - recv_buf.data(), - recv_counts, - comm).wait(); + ccl::allgatherv(send_buf.data(), count, recv_buf.data(), recv_counts, comm).wait(); /* check correctness of recv_buf */ for (i = 0; i < count; i++) { diff --git a/examples/cpu/cpu_allreduce_bf16_test.cpp b/examples/cpu/cpu_allreduce_bf16_test.cpp index 38ce444cf..9ab1dc481 100644 --- a/examples/cpu/cpu_allreduce_bf16_test.cpp +++ b/examples/cpu/cpu_allreduce_bf16_test.cpp @@ -28,15 +28,15 @@ #define CHECK_ERROR(send_buf, recv_buf, comm) \ { 
\ /* https://www.mcs.anl.gov/papers/P4093-0713_1.pdf */ \ - size_t comm_size = comm.size(); \ + int comm_size = comm.size(); \ double log_base2 = log(comm_size) / log(2); \ double g = (log_base2 * BF16_PRECISION) / (1 - (log_base2 * BF16_PRECISION)); \ for (size_t i = 0; i < COUNT; i++) { \ - double expected = ((comm_size * (comm_size - 1) / 2) + ((float)(i) * comm_size)); \ + double expected = ((comm_size * (comm_size - 1) / 2) + ((float)(i)*comm_size)); \ double max_error = g * expected; \ if (fabs(max_error) < fabs(expected - recv_buf[i])) { \ printf( \ - "[%zu] got recv_buf[%zu] = %0.7f, but expected = %0.7f, max_error = %0.16f\n", \ + "[%d] got recv_buf[%zu] = %0.7f, but expected = %0.7f, max_error = %0.16f\n", \ comm.rank(), \ i, \ recv_buf[i], \ @@ -50,7 +50,6 @@ using namespace std; int main() { - const size_t count = 4096; size_t idx = 0; @@ -93,12 +92,9 @@ int main() { else { cout << "BF16 is enabled\n"; convert_fp32_to_bf16_arrays(send_buf, send_buf_bf16, count); - ccl::allreduce(send_buf_bf16, - recv_buf_bf16, - count, - ccl::datatype::bfloat16, - ccl::reduction::sum, - comm).wait(); + ccl::allreduce( + send_buf_bf16, recv_buf_bf16, count, ccl::datatype::bfloat16, ccl::reduction::sum, comm) + .wait(); convert_bf16_to_fp32_arrays(recv_buf_bf16, recv_buf, count); CHECK_ERROR(send_buf, recv_buf, comm); diff --git a/examples/cpu/cpu_allreduce_test.cpp b/examples/cpu/cpu_allreduce_test.cpp index 67a623d3f..e80963812 100644 --- a/examples/cpu/cpu_allreduce_test.cpp +++ b/examples/cpu/cpu_allreduce_test.cpp @@ -21,7 +21,6 @@ using namespace std; int main() { - const size_t count = 4096; size_t i = 0; @@ -64,11 +63,7 @@ int main() { } /* invoke allreduce */ - ccl::allreduce(send_buf, - recv_buf, - count, - ccl::reduction::sum, - comm).wait(); + ccl::allreduce(send_buf, recv_buf, count, ccl::reduction::sum, comm).wait(); /* check correctness of recv_buf */ for (i = 0; i < count; i++) { diff --git a/examples/cpu/custom_allreduce.cpp 
b/examples/cpu/custom_allreduce.cpp index 38495a1f1..5ccfc9829 100644 --- a/examples/cpu/custom_allreduce.cpp +++ b/examples/cpu/custom_allreduce.cpp @@ -24,7 +24,7 @@ int size, rank; ccl::datatype custom_dtype; -std::string global_match_id; +ccl::string_class global_match_id; typedef void (*expected_fn_t)(void*, size_t); typedef void (*fill_fn_t)(void*, size_t, size_t); @@ -49,28 +49,12 @@ typedef int (*check_fn_t)(void*, size_t, expected_fn_t); } while (0) /* primitive operations for custom datatype */ -void custom_2x(void* in_elem, void* out_elem) { - for (size_t idx = 0; idx < CUSTOM_REPEAT_COUNT; idx++) { - ((CUSTOM_BASE_DTYPE*)out_elem)[idx] = 2 * ((CUSTOM_BASE_DTYPE*)in_elem)[idx]; - } -} - void custom_sum(void* in_elem, void* inout_elem) { for (size_t idx = 0; idx < CUSTOM_REPEAT_COUNT; idx++) { ((CUSTOM_BASE_DTYPE*)inout_elem)[idx] += ((CUSTOM_BASE_DTYPE*)in_elem)[idx]; } } -void custom_to_char(void* in_elem, char* out_elem) { - *out_elem = ((CUSTOM_BASE_DTYPE*)in_elem)[0]; -} - -void custom_from_char(char* in_elem, void* out_elem) { - for (size_t idx = 0; idx < CUSTOM_REPEAT_COUNT; idx++) { - ((CUSTOM_BASE_DTYPE*)out_elem)[idx] = (CUSTOM_BASE_DTYPE)(*in_elem); - } -} - void custom_set(void* elem, size_t base_value) { for (size_t idx = 0; idx < CUSTOM_REPEAT_COUNT; idx++) { ((CUSTOM_BASE_DTYPE*)elem)[idx] = (CUSTOM_BASE_DTYPE)(base_value); @@ -200,140 +184,6 @@ void expected_custom_6(void* elem, size_t idx) { custom_set(elem, 2 * idx); } -void do_prologue_2x(const void* in_buf, - size_t in_count, - ccl::datatype in_dtype, - void** out_buf, - size_t* out_count, - ccl::datatype* out_dtype, - const ccl::fn_context* context) { - ASSERT((in_dtype == ccl::datatype::float32) || (in_dtype == custom_dtype), - "unexpected in_dtype %d", - static_cast(in_dtype)); - ASSERT(out_buf, "null ptr"); - ASSERT(context->offset == 0, "wrong offset for prologue func, should be 0"); - ASSERT(!strcmp(context->match_id, global_match_id.c_str()), "wrong match_id"); - - if (out_buf) 
- *out_buf = (void*)in_buf; - if (out_count) - *out_count = in_count; - if (out_dtype) - *out_dtype = in_dtype; - - for (size_t idx = 0; idx < in_count; idx++) { - if (in_dtype == ccl::datatype::float32) { - ((float*)(*out_buf))[idx] = ((float*)in_buf)[idx] * 2; - } - else if (in_dtype == custom_dtype) { - custom_2x((char*)in_buf + idx * CUSTOM_DTYPE_SIZE, - (char*)(*out_buf) + idx * CUSTOM_DTYPE_SIZE); - } - else { - ASSERT(0, "unexpected dtype %d", static_cast(in_dtype)); - } - } -} - -void do_epilogue_2x(const void* in_buf, - size_t in_count, - ccl::datatype in_dtype, - void* out_buf, - size_t* out_count, - ccl::datatype* out_dtype, - const ccl::fn_context* context) { - ASSERT((in_dtype == ccl::datatype::float32) || (in_dtype == custom_dtype), - "unexpected in_dtype %d", - static_cast(in_dtype)); - ASSERT(context->offset == 0, "wrong offset for epilogue func, should be 0"); - ASSERT(!strcmp(context->match_id, global_match_id.c_str()), "wrong match_id"); - - if (out_count) - *out_count = in_count; - - for (size_t idx = 0; idx < in_count; idx++) { - if (in_dtype == ccl::datatype::float32) { - ((float*)out_buf)[idx] = ((float*)in_buf)[idx] * 2; - } - else if (in_dtype == custom_dtype) { - custom_2x((char*)in_buf + idx * CUSTOM_DTYPE_SIZE, - (char*)out_buf + idx * CUSTOM_DTYPE_SIZE); - } - else { - ASSERT(0, "unexpected dtype %d", static_cast(in_dtype)); - } - } -} - -void do_prologue_dtype_to_char(const void* in_buf, - size_t in_count, - ccl::datatype in_dtype, - void** out_buf, - size_t* out_count, - ccl::datatype* out_dtype, - const ccl::fn_context* context) { - ASSERT((in_dtype == ccl::datatype::float32) || (in_dtype == custom_dtype), - "unexpected in_dtype %d", - static_cast(in_dtype)); - ASSERT(out_buf, "null ptr"); - ASSERT(context->offset == 0, "wrong offset for prologue func, should be 0"); - ASSERT(!strcmp(context->match_id, global_match_id.c_str()), "wrong match_id"); - - if (out_buf) - *out_buf = malloc(in_count); /* will be deallocated in 
do_epilogue_char_to_dtype */ - if (out_count) - *out_count = in_count; - if (out_dtype) - *out_dtype = ccl::datatype::int8; - - for (size_t idx = 0; idx < in_count; idx++) { - if (in_dtype == ccl::datatype::float32) { - float fval = ((float*)in_buf)[idx]; - int ival = (int)fval; - ((char*)(*out_buf))[idx] = (char)(ival % 256); - } - else if (in_dtype == custom_dtype) { - custom_to_char((char*)in_buf + idx * CUSTOM_DTYPE_SIZE, (char*)(*out_buf) + idx); - } - else { - ASSERT(0, "unexpected dtype %d", static_cast(in_dtype)); - } - } -} - -void do_epilogue_char_to_dtype(const void* in_buf, - size_t in_count, - ccl::datatype in_dtype, - void* out_buf, - size_t* out_count, - ccl::datatype out_dtype, - const ccl::fn_context* context) { - ASSERT(in_dtype == ccl::datatype::int8, "unexpected in_dtype %d", static_cast(in_dtype)); - ASSERT((out_dtype == ccl::datatype::float32) || (out_dtype == custom_dtype), - "unexpected out_dtype %d", - static_cast(out_dtype)); - ASSERT(context->offset == 0, "wrong offset for epilogue func, should be 0"); - ASSERT(!strcmp(context->match_id, global_match_id.c_str()), "wrong match_id"); - - if (out_count) - *out_count = in_count; - - for (size_t idx = 0; idx < in_count; idx++) { - if (out_dtype == ccl::datatype::float32) { - ((float*)out_buf)[idx] = (float)(((char*)in_buf)[idx]); - } - else if (out_dtype == custom_dtype) { - custom_from_char((char*)in_buf + idx, (char*)out_buf + idx * CUSTOM_DTYPE_SIZE); - } - else { - ASSERT(0, "unexpected dtype %d", static_cast(out_dtype)); - } - } - - if (in_buf != out_buf) - free((void*)in_buf); -} - void do_reduction_sum(const void* in_buf, size_t in_count, void* inout_buf, @@ -341,8 +191,7 @@ void do_reduction_sum(const void* in_buf, ccl::datatype dtype, const ccl::fn_context* context) { size_t dtype_size; - auto& env = ccl::environment::instance(); - dtype_size = env.get_datatype_size(dtype); + dtype_size = ccl::get_datatype_size(dtype); ASSERT((dtype == ccl::datatype::int8) || (dtype == 
ccl::datatype::float32) || (dtype == custom_dtype), @@ -379,8 +228,7 @@ void do_reduction_null(const void* in_buf, ccl::datatype dtype, const ccl::fn_context* context) { size_t dtype_size; - auto& env = ccl::environment::instance(); - dtype_size = env.get_datatype_size(dtype); + dtype_size = ccl::get_datatype_size(dtype); ASSERT((dtype == ccl::datatype::int8) || (dtype == ccl::datatype::float32) || (dtype == custom_dtype), @@ -416,8 +264,7 @@ void do_reduction_custom(const void* in_buf, ccl::datatype dtype, const ccl::fn_context* context) { size_t dtype_size; - auto& env = ccl::environment::instance(); - dtype_size = env.get_datatype_size(dtype); + dtype_size = ccl::get_datatype_size(dtype); ASSERT((dtype == ccl::datatype::float32) || (dtype == custom_dtype), "unexpected in_dtype %d", @@ -445,7 +292,6 @@ void do_reduction_custom(const void* in_buf, } int main() { - setenv("CCL_ATL_TRANSPORT", "ofi", 1); ccl::init(); @@ -454,22 +300,20 @@ int main() { MPI_Comm_size(MPI_COMM_WORLD, &size); MPI_Comm_rank(MPI_COMM_WORLD, &rank); - auto& env = ccl::environment::instance(); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { - kvs = env.create_main_kvs(); + kvs = ccl::create_main_kvs(); main_addr = kvs->get_address(); MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); } else { MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); - kvs = env.create_kvs(main_addr); + kvs = ccl::create_kvs(main_addr); } - auto comm = env.create_communicator(size, rank, kvs); - auto attr = env.create_operation_attr(); + auto comm = ccl::create_communicator(size, rank, kvs); + auto attr = ccl::create_operation_attr(); float float_send_buf[MSG_SIZE_COUNT]; float float_recv_buf[MSG_SIZE_COUNT]; @@ -481,9 +325,9 @@ int main() { char custom_send_buf[MSG_SIZE_COUNT * CUSTOM_DTYPE_SIZE]; char custom_recv_buf[MSG_SIZE_COUNT * CUSTOM_DTYPE_SIZE]; - std::string base_match_id = attr.get(); + ccl::string_class 
base_match_id = attr.get(); attr.set(true); - std::string match_id; + ccl::string_class match_id; for (size_t idx = 0; idx < 2; idx++) { if (rank == 0) @@ -508,55 +352,12 @@ int main() { check_fn, expected_fn, "regular_allreduce"); - - /* prologue */ - expected_fn = (idx == 0) ? expected_float_2 : expected_custom_2; - match_id = base_match_id + "_prologue_" + std::to_string(idx); - attr.set(match_id); - attr.set((ccl::prologue_fn)do_prologue_2x); - RUN_COLLECTIVE( - ccl::allreduce( - send_buf, recv_buf, MSG_SIZE_COUNT, dtype, ccl::reduction::sum, comm, attr), - fill_fn, - check_fn, - expected_fn, - "allreduce_with_prologue"); - - /* epilogue */ - expected_fn = (idx == 0) ? expected_float_2 : expected_custom_2; - match_id = base_match_id + "_epilogue_" + std::to_string(idx); - attr.set(match_id); - attr.set((ccl::prologue_fn) nullptr); - attr.set((ccl::epilogue_fn)do_epilogue_2x); - RUN_COLLECTIVE( - ccl::allreduce( - send_buf, recv_buf, MSG_SIZE_COUNT, dtype, ccl::reduction::sum, comm, attr), - fill_fn, - check_fn, - expected_fn, - "allreduce_with_epilogue"); - - /* prologue and epilogue */ - expected_fn = (idx == 0) ? expected_float_4 : expected_custom_4; - match_id = base_match_id + "_prologue_and_epilogue_" + std::to_string(idx); - attr.set(match_id); - attr.set((ccl::prologue_fn)do_prologue_2x); - attr.set((ccl::epilogue_fn)do_epilogue_2x); - RUN_COLLECTIVE( - ccl::allreduce( - send_buf, recv_buf, MSG_SIZE_COUNT, dtype, ccl::reduction::sum, comm, attr), - fill_fn, - check_fn, - expected_fn, - "allreduce_with_prologue_and_epilogue"); } /* reduction_sum */ expected_fn = (idx == 0) ? expected_float_1 : expected_custom_1; match_id = base_match_id + "_reduction_sum_" + std::to_string(idx); attr.set(match_id); - attr.set((ccl::prologue_fn) nullptr); - attr.set((ccl::epilogue_fn) nullptr); attr.set((ccl::reduction_fn)do_reduction_sum); RUN_COLLECTIVE( ccl::allreduce( @@ -573,8 +374,6 @@ int main() { expected_fn = (idx == 0) ? 
expected_float_3 : expected_custom_3; match_id = base_match_id + "_reduction_null_" + std::to_string(idx); attr.set(match_id); - attr.set((ccl::prologue_fn) nullptr); - attr.set((ccl::epilogue_fn) nullptr); attr.set((ccl::reduction_fn)do_reduction_null); RUN_COLLECTIVE( ccl::allreduce( @@ -591,8 +390,6 @@ int main() { expected_fn = (idx == 0) ? expected_float_5 : expected_custom_5; match_id = base_match_id + "_reduction_custom_" + std::to_string(idx); attr.set(match_id); - attr.set((ccl::prologue_fn) nullptr); - attr.set((ccl::epilogue_fn) nullptr); attr.set((ccl::reduction_fn)do_reduction_custom); RUN_COLLECTIVE( ccl::allreduce( @@ -601,108 +398,6 @@ int main() { check_fn, expected_fn, "allreduce_with_reduction_custom"); - - /* prologue and reduction_sum */ - expected_fn = (idx == 0) ? expected_float_2 : expected_custom_2; - match_id = base_match_id + "_prologue_and_reduction_sum_" + std::to_string(idx); - attr.set(match_id); - attr.set((ccl::prologue_fn)do_prologue_2x); - attr.set((ccl::epilogue_fn) nullptr); - attr.set((ccl::reduction_fn)do_reduction_sum); - RUN_COLLECTIVE( - ccl::allreduce( - send_buf, recv_buf, MSG_SIZE_COUNT, dtype, ccl::reduction::custom, comm, attr), - fill_fn, - check_fn, - expected_fn, - "allreduce_with_prologue_and_reduction_sum"); - - /* epilogue and reduction_sum */ - expected_fn = (idx == 0) ? expected_float_2 : expected_custom_2; - match_id = base_match_id + "_epilogue_and_reduction_sum_" + std::to_string(idx); - attr.set(match_id); - attr.set((ccl::prologue_fn) nullptr); - attr.set((ccl::epilogue_fn)do_epilogue_2x); - attr.set((ccl::reduction_fn)do_reduction_sum); - RUN_COLLECTIVE( - ccl::allreduce( - send_buf, recv_buf, MSG_SIZE_COUNT, dtype, ccl::reduction::custom, comm, attr), - fill_fn, - check_fn, - expected_fn, - "allreduce_with_epilogue_and_reduction_sum"); - - /* prologue and epilogue and reduction_sum */ - expected_fn = (idx == 0) ? 
expected_float_4 : expected_custom_4; - match_id = - base_match_id + "_prologue_and_epilogue_and_reduction_sum_" + std::to_string(idx); - attr.set(match_id); - attr.set((ccl::prologue_fn)do_prologue_2x); - attr.set((ccl::epilogue_fn)do_epilogue_2x); - attr.set((ccl::reduction_fn)do_reduction_sum); - RUN_COLLECTIVE( - ccl::allreduce( - send_buf, recv_buf, MSG_SIZE_COUNT, dtype, ccl::reduction::custom, comm, attr), - fill_fn, - check_fn, - expected_fn, - "allreduce_with_prologue_and_epilogue_and_reduction_sum"); - - /* prologue and epilogue and reduction_null */ - if (size == 1) - expected_fn = (idx == 0) ? expected_float_4 : expected_custom_4; - else - expected_fn = (idx == 0) ? expected_float_3 : expected_custom_3; - match_id = - base_match_id + "_prologue_and_epilogue_and_reduction_null_" + std::to_string(idx); - attr.set(match_id); - attr.set((ccl::prologue_fn)do_prologue_2x); - attr.set((ccl::epilogue_fn)do_epilogue_2x); - attr.set((ccl::reduction_fn)do_reduction_null); - RUN_COLLECTIVE( - ccl::allreduce( - send_buf, recv_buf, MSG_SIZE_COUNT, dtype, ccl::reduction::custom, comm, attr), - fill_fn, - check_fn, - expected_fn, - "allreduce_with_prologue_and_epilogue_and_reduction_null"); - - /* prologue and epilogue and reduction_sum */ - expected_fn = (idx == 0) ? expected_float_1 : expected_custom_1; - match_id = - base_match_id + "_prologue_and_epilogue_and_reduction_sum2_" + std::to_string(idx); - attr.set(match_id); - attr.set( - (ccl::prologue_fn)do_prologue_dtype_to_char); - attr.set( - (ccl::epilogue_fn)do_epilogue_char_to_dtype); - attr.set((ccl::reduction_fn)do_reduction_sum); - RUN_COLLECTIVE( - ccl::allreduce( - send_buf, recv_buf, MSG_SIZE_COUNT, dtype, ccl::reduction::custom, comm, attr), - fill_fn, - check_fn, - expected_fn, - "allreduce_with_prologue_and_epilogue_and_reduction_sum2"); - - /* epilogue and reduction_custom */ - if (size == 1) - expected_fn = (idx == 0) ? expected_float_1 : expected_custom_1; - else - expected_fn = (idx == 0) ? 
expected_float_6 : expected_custom_6; - match_id = - base_match_id + "_prologue_and_epilogue_and_reduction_custom_" + std::to_string(idx); - attr.set(match_id); - attr.set((ccl::prologue_fn) nullptr); - attr.set((ccl::epilogue_fn)do_epilogue_2x); - attr.set((ccl::reduction_fn)do_reduction_custom); - RUN_COLLECTIVE( - ccl::allreduce( - send_buf, recv_buf, MSG_SIZE_COUNT, dtype, ccl::reduction::custom, comm, attr), - fill_fn, - check_fn, - expected_fn, - "allreduce_with_epilogue_and_reduction_custom"); } if (rank == 0) diff --git a/examples/cpu/datatype.cpp b/examples/cpu/datatype.cpp index 88863db66..40fdaf659 100644 --- a/examples/cpu/datatype.cpp +++ b/examples/cpu/datatype.cpp @@ -63,12 +63,12 @@ void check_allreduce(const ccl::communicator &comm) { for (size_t idx = 0; idx < max_dtype_count; idx++) { reqs[idx] = ccl::allreduce(send_bufs[idx].data(), - recv_bufs[idx].data(), - COUNT, - dtypes[idx], - ccl::reduction::custom, - comm, - attr); + recv_bufs[idx].data(), + COUNT, + dtypes[idx], + ccl::reduction::custom, + comm, + attr); } for (size_t idx = 0; idx < max_dtype_count; idx++) { @@ -125,6 +125,10 @@ void check_create_and_free() { } int main() { + /** + * The example only works with CCL_ATL_TRANSPORT=ofi + */ + setenv("CCL_ATL_TRANSPORT", "ofi", 0); ccl::init(); diff --git a/examples/cpu/external_kvs.cpp b/examples/cpu/external_kvs.cpp new file mode 100644 index 000000000..f2c272a5c --- /dev/null +++ b/examples/cpu/external_kvs.cpp @@ -0,0 +1,101 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "base.hpp" + +class external_kvs : public ccl::kvs_interface { +public: + external_kvs(ccl::shared_ptr_class kvs) : kvs(kvs) {} + + virtual ccl::vector_class get(const ccl::string_class& key) { + return kvs->get(key); + } + + virtual void set(const ccl::string_class& key, const ccl::vector_class& data) { + return kvs->set(key, data); + } + +private: + ccl::shared_ptr_class kvs; +}; + +void run_collective(const char* cmd_name, + const std::vector& send_buf, + std::vector& recv_buf, + const ccl::communicator& comm, + const ccl::allreduce_attr& attr) { + std::chrono::system_clock::duration exec_time{ 0 }; + float expected = (comm.size() - 1) * (static_cast(comm.size()) / 2); + + ccl::barrier(comm); + + for (size_t idx = 0; idx < ITERS; ++idx) { + auto start = std::chrono::system_clock::now(); + ccl::allreduce( + send_buf.data(), recv_buf.data(), recv_buf.size(), ccl::reduction::sum, comm, attr) + .wait(); + exec_time += std::chrono::system_clock::now() - start; + } + + for (size_t idx = 0; idx < recv_buf.size(); idx++) { + if (recv_buf[idx] != expected) { + fprintf(stderr, "idx %zu, expected %4.4f, got %4.4f\n", idx, expected, recv_buf[idx]); + + std::cout << "FAILED" << std::endl; + std::terminate(); + } + } + + ccl::barrier(comm); + + std::cout << "avg time of " << cmd_name << ": " + << std::chrono::duration_cast(exec_time).count() / ITERS + << ", us" << std::endl; +} + +int main() { + ccl::init_attr init_attr = ccl::create_init_attr(); + ccl::init(init_attr); + + int size, rank; + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + ccl::shared_ptr_class kvs; + ccl::kvs::address_type main_addr; + if (rank == 0) { + kvs = ccl::create_main_kvs(); + main_addr = kvs->get_address(); + MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + } + else { + 
MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + kvs = ccl::create_kvs(main_addr); + } + + auto ext_kvs = std::make_shared(kvs); + + auto comm = ccl::create_communicator(size, rank, ext_kvs); + auto attr = ccl::create_operation_attr(); + + MSG_LOOP(comm, std::vector send_buf(msg_count, static_cast(comm.rank())); + std::vector recv_buf(msg_count); + run_collective("regular allreduce", send_buf, recv_buf, comm, attr);); + + MPI_Finalize(); + + return 0; +} diff --git a/examples/cpu/priority_allreduce.cpp b/examples/cpu/priority_allreduce.cpp index af4d56043..6299b22a8 100644 --- a/examples/cpu/priority_allreduce.cpp +++ b/examples/cpu/priority_allreduce.cpp @@ -84,8 +84,7 @@ double msg_timers_stddev[MSG_COUNT]; size_t comp_delay_ms; -void do_iter(size_t iter_idx, - ccl::communicator& comm) { +void do_iter(size_t iter_idx, ccl::communicator& comm) { if (comm.rank() == 0) { printf("started iter %zu\n", iter_idx); fflush(stdout); @@ -104,7 +103,7 @@ void do_iter(size_t iter_idx, for (idx = 0; idx < MSG_COUNT; idx++) { sprintf(match_id, "%zu", idx); - attr.set(std::string(match_id)); + attr.set(ccl::string_class(match_id)); tmp_start_timer = when(); ccl::allreduce(msg_buffers[idx], @@ -113,7 +112,8 @@ void do_iter(size_t iter_idx, ccl::datatype::float32, ccl::reduction::sum, comm, - attr).wait(); + attr) + .wait(); tmp_stop_timer = when(); msg_iso_timers[idx] += (tmp_stop_timer - tmp_start_timer); } @@ -137,7 +137,7 @@ void do_iter(size_t iter_idx, sprintf(match_id, "%zu", idx); - attr.set(std::string(match_id)); + attr.set(ccl::string_class(match_id)); msg_starts[idx] = when(); tmp_start_timer = when(); @@ -192,7 +192,6 @@ void do_iter(size_t iter_idx, } int main() { - setenv("CCL_PRIORITY", "direct", 0); ccl::init(); @@ -285,17 +284,10 @@ int main() { std::vector recv_iter_timers(size); std::vector recv_iter_timers_counts(size, 1); - ccl::allgatherv(msg_timers, - MSG_COUNT, - recv_msg_timers.data(), - recv_msg_timers_counts, - 
comm).wait(); - - ccl::allgatherv(&iter_timer, - 1, - recv_iter_timers.data(), - recv_iter_timers_counts, - comm).wait(); + ccl::allgatherv(msg_timers, MSG_COUNT, recv_msg_timers.data(), recv_msg_timers_counts, comm) + .wait(); + + ccl::allgatherv(&iter_timer, 1, recv_iter_timers.data(), recv_iter_timers_counts, comm).wait(); if (rank == 0) { size_t rank_idx; diff --git a/examples/cpu/reduce.cpp b/examples/cpu/reduce.cpp index f49dad3dd..fca41d025 100644 --- a/examples/cpu/reduce.cpp +++ b/examples/cpu/reduce.cpp @@ -34,7 +34,8 @@ void run_collective(const char* cmd_name, ccl::reduction::sum, COLL_ROOT, comm, - attr).wait(); + attr) + .wait(); exec_time += std::chrono::system_clock::now() - start; } @@ -56,7 +57,6 @@ void run_collective(const char* cmd_name, } int main() { - ccl::init(); int size, rank; diff --git a/examples/cpu/reduce_scatter.cpp b/examples/cpu/reduce_scatter.cpp index 59d3c7c06..29accda5d 100644 --- a/examples/cpu/reduce_scatter.cpp +++ b/examples/cpu/reduce_scatter.cpp @@ -27,12 +27,9 @@ void run_collective(const char* cmd_name, for (size_t idx = 0; idx < ITERS; ++idx) { auto start = std::chrono::system_clock::now(); - ccl::reduce_scatter(send_buf.data(), - recv_buf.data(), - recv_buf.size(), - ccl::reduction::sum, - comm, - attr).wait(); + ccl::reduce_scatter( + send_buf.data(), recv_buf.data(), recv_buf.size(), ccl::reduction::sum, comm, attr) + .wait(); exec_time += std::chrono::system_clock::now() - start; } @@ -53,7 +50,6 @@ void run_collective(const char* cmd_name, } int main() { - ccl::init(); int size, rank; @@ -76,7 +72,8 @@ int main() { auto comm = ccl::create_communicator(size, rank, kvs); auto attr = ccl::create_operation_attr(); - MSG_LOOP(comm, std::vector send_buf(msg_count * comm.size(), static_cast(comm.rank())); + MSG_LOOP(comm, + std::vector send_buf(msg_count * comm.size(), static_cast(comm.rank())); std::vector recv_buf(msg_count); attr.set(false); run_collective("warmup reduce_scatter", send_buf, recv_buf, comm, attr); diff 
--git a/examples/cpu/unordered_allreduce.cpp b/examples/cpu/unordered_allreduce.cpp index aa6ce39b6..45c44a3ec 100644 --- a/examples/cpu/unordered_allreduce.cpp +++ b/examples/cpu/unordered_allreduce.cpp @@ -23,13 +23,17 @@ #include "base.hpp" int main() { + /** + * The example only works with CCL_ATL_TRANSPORT=ofi + */ + setenv("CCL_ATL_TRANSPORT", "ofi", 0); setenv("CCL_UNORDERED_COLL", "1", 1); const size_t buf_size = 1024; const size_t iter_count = 64; - std::vector match_ids; + std::vector match_ids; /* event, operation idx */ std::list> active_ops; @@ -76,36 +80,31 @@ int main() { } for (size_t iter = 0; iter < iter_count; ++iter) { - std::cout << "starting iter " << iter << std::endl; size_t start_idx = distribution(rand_dev); size_t rank_idx = start_idx; for (auto idx = 0; idx < size; idx++) { - - std::cout << "submit allreduce " << rank_idx - << " for match_id " << match_ids[rank_idx] << std::endl; + std::cout << "submit allreduce " << rank_idx << " for match_id " << match_ids[rank_idx] + << std::endl; attr.set(match_ids[rank_idx]); - active_ops.emplace_back( - ccl::allreduce(send_bufs[rank_idx].data(), - recv_bufs[rank_idx].data(), - buf_size, - ccl::reduction::sum, - comm, - attr), - rank_idx); + active_ops.emplace_back(ccl::allreduce(send_bufs[rank_idx].data(), + recv_bufs[rank_idx].data(), + buf_size, + ccl::reduction::sum, + comm, + attr), + rank_idx); rank_idx = (rank_idx + 1) % size; } while (!active_ops.empty()) { for (auto it = active_ops.begin(); it != active_ops.end();) { - if (it->first.test()) { - float expected = (it->second + 1) * size; printf( "completed allreduce %zu for match_id %s. 
Actual %3.2f, expected %3.2f\n", diff --git a/examples/include/base.hpp b/examples/include/base.hpp index 0b3152778..fc4239f5b 100644 --- a/examples/include/base.hpp +++ b/examples/include/base.hpp @@ -67,7 +67,7 @@ using namespace cl::sycl::access; START_MSG_SIZE_POWER, \ COLL_ROOT); \ std::vector msg_counts(MSG_SIZE_COUNT); \ - std::vector msg_match_ids(MSG_SIZE_COUNT); \ + std::vector msg_match_ids(MSG_SIZE_COUNT); \ for (size_t idx = 0; idx < MSG_SIZE_COUNT; ++idx) { \ msg_counts[idx] = 1u << (START_MSG_SIZE_POWER + idx); \ msg_match_ids[idx] = std::to_string(msg_counts[idx]); \ @@ -114,4 +114,12 @@ double when(void) { return (double)(tv.tv_sec - tv_base.tv_sec) * 1.0e6 + (double)(tv.tv_usec - tv_base.tv_usec); } +void mpi_finalize() { + int is_finalized = 0; + MPI_Finalized(&is_finalized); + + if (!is_finalized) + MPI_Finalize(); +} + #endif /* BASE_HPP */ diff --git a/examples/include/base_utils.hpp b/examples/include/base_utils.hpp index 539b76ae1..80330334d 100644 --- a/examples/include/base_utils.hpp +++ b/examples/include/base_utils.hpp @@ -179,6 +179,72 @@ void str_to_mset(const char* input, std::multiset& outpu output.insert(ccl::from_string(processes_input)); } } + +std::shared_ptr build_kvs(int mpi_rank) { + std::shared_ptr kvs_instance; + ccl::kvs::address_type main_addr; + if (mpi_rank == 0) { + kvs_instance = ccl::create_main_kvs(); + main_addr = kvs_instance->get_address(); + MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + } + else { + MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + kvs_instance = ccl::create_kvs(main_addr); + } + return kvs_instance; +} + +inline size_t take_mpi_rank_id_offest(const size_t mpi_rank_in_cluster, + const int mpi_size, + const size_t total_device_in_cluster) { + if (mpi_size > 2) { + throw std::runtime_error(std::string(__FUNCTION__) + + " - Only TWO processes support case !\n"); + } + return total_device_in_cluster; +} + 
+ccl::process_device_indices_type extract_indices_for_threads( + const size_t mpi_rank_in_cluster, + const int current_mpi_rank, + std::vector thread_gpu_affinity, + size_t& total_device_in_cluster, + std::vector& total_devices_in_process, + std::map>& devices_for_current_mpi_rank) { + ccl::process_device_indices_type thread_group_affinity; + + for (size_t thread_index = 0; thread_index < thread_gpu_affinity.size(); thread_index++) { + ccl::device_indices_type device_group_affinity; + str_to_mset( + thread_gpu_affinity[thread_index].c_str(), device_group_affinity, ','); + + std::cout << " Extracted GPU indices for thread by id: " << thread_index + << ", devices in threads count: " << device_group_affinity.size() << std::endl; + total_device_in_cluster += device_group_affinity.size(); + total_devices_in_process[mpi_rank_in_cluster] += device_group_affinity.size(); + thread_group_affinity[thread_index] = device_group_affinity; + + if (mpi_rank_in_cluster == static_cast(current_mpi_rank)) { + for (auto device_vendor_id : device_group_affinity) { + devices_for_current_mpi_rank[thread_index].push_back( + ccl::create_from_index(device_vendor_id).device); + } + } + } + return thread_group_affinity; +} + +std::vector set_union_devices_in_current_process( + const std::map>& devices_for_mpi_rank) { + std::vector devices_in_process; + for (auto& thread_devices : devices_for_mpi_rank) { + devices_in_process.insert( + devices_in_process.end(), thread_devices.second.begin(), thread_devices.second.end()); + } + return devices_in_process; +} + #endif //MULTI_GPU_SUPPORT } // namespace utils #endif /* BASE_UTILS_HPP */ diff --git a/examples/include/sycl_base.hpp b/examples/include/sycl_base.hpp index 027c60291..bad23be3a 100644 --- a/examples/include/sycl_base.hpp +++ b/examples/include/sycl_base.hpp @@ -18,13 +18,12 @@ #include #include -#include -#include #include #include #include #include +#include "base.hpp" #include "base_utils.hpp" #include "oneapi/ccl.hpp" @@ -55,7 +54,6 @@ 
inline bool has_accelerator() { } inline bool check_sycl_usm(queue& q, usm::alloc alloc_type) { - bool ret = true; device d = q.get_device(); @@ -73,67 +71,272 @@ inline bool check_sycl_usm(queue& q, usm::alloc alloc_type) { return ret; } -inline bool create_sycl_queue(int argc, - char* argv[], - queue& q) { +std::string get_preferred_gpu_platform_name() { + std::string backend; + std::string result; - auto exception_handler = [&](exception_list elist) { - for (exception_ptr const& e : elist) { - try { - rethrow_exception(e); + if (getenv("SYCL_BE") == nullptr) { + backend = "OpenCL"; + } + else if (getenv("SYCL_BE") != nullptr) { + if (std::strcmp(getenv("SYCL_BE"), "PI_LEVEL_ZERO") == 0) { + backend = "Level-Zero"; + } + else if (std::strcmp(getenv("SYCL_BE"), "PI_OPENCL") == 0) { + backend = "OpenCL"; + } + else { + throw std::runtime_error("invalid backend: " + std::string(getenv("SYCL_BE"))); + } + } + + auto plaform_list = sycl::platform::get_platforms(); + + for (const auto& platform : plaform_list) { + auto platform_name = platform.get_info(); + + auto devices = platform.get_devices(); + auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) { + return d.is_gpu(); + }); + + if (gpu_dev == devices.end()) { + // cout << "platform [" << platform_name + // << "] does not contain GPU devices, skipping\n"; + continue; + } + + if (platform_name.find(backend) == std::string::npos) { + // cout << "platform [" << platform_name + // << "] does not match with requested " + // << backend << ", skipping\n"; + continue; + } + + result = platform_name; + } + + if (result.empty()) + throw std::runtime_error("can not find preferred GPU platform"); + + return result; +} + +std::vector create_sycl_gpu_devices() { + constexpr char dev_prefix[] = "-- "; + constexpr char sub_dev_prefix[] = "---- "; + + std::vector result; + auto plaform_list = sycl::platform::get_platforms(); + auto preferred_platform_name = get_preferred_gpu_platform_name(); + + 
cout << "preferred platform: [" << preferred_platform_name << "]\n"; + + for (const auto& platform : plaform_list) { + auto platform_name = platform.get_info(); + + if (platform_name.compare(preferred_platform_name) != 0) + continue; + + cout << "platform: [" << platform_name << "]\n"; + + auto device_list = platform.get_devices(); + + for (const auto& device : device_list) { + auto device_name = device.get_info(); + + if (!device.is_gpu()) { + cout << dev_prefix << "device [" << device_name << "] is not GPU, skipping\n"; + continue; } - catch (std::exception const& e) { - cout << "failure\n"; + + auto part_props = device.get_info(); + + if (std::find(part_props.begin(), + part_props.end(), + info::partition_property::partition_by_affinity_domain) == + part_props.end()) { + cout << dev_prefix << "device [" << device_name + << "] does not support partition by affinity domain" + << ", use root device\n"; + result.push_back(device); + continue; + } + + auto part_affinity_domains = + device.get_info(); + + if (std::find(part_affinity_domains.begin(), + part_affinity_domains.end(), + info::partition_affinity_domain::next_partitionable) == + part_affinity_domains.end()) { + cout << dev_prefix << "device [" << device_name + << "] does not support next_partitionable affinity domain" + << ", use root device\n"; + result.push_back(device); + continue; + } + + cout << dev_prefix << "device [" << device_name << "] should provide " + << device.template get_info() + << " sub-devices\n"; + + auto sub_devices = + device.create_sub_devices( + info::partition_affinity_domain::next_partitionable); + + if (sub_devices.empty()) { + /* TODO: remove when SYCL/L0 sub-devices will be supported */ + cout << dev_prefix << "device [" << device_name << "] does not provide sub-devices" + << ", use root device\n"; + result.push_back(device); + continue; + } + + cout << dev_prefix << "device [" << device_name << "] provides " << sub_devices.size() + << " sub-devices\n"; + 
result.insert(result.end(), sub_devices.begin(), sub_devices.end()); + + for (auto idx = 0; idx < sub_devices.size(); idx++) { + cout << sub_dev_prefix << "sub-device " << idx << ": [" + << sub_devices[idx].get_info() << "]\n"; } } - }; + } - unique_ptr selector; - if (argc >= 2) { - if (strcmp(argv[1], "cpu") == 0) { - selector.reset(new cpu_selector()); + if (result.empty()) { + throw std::runtime_error("no GPU devices found"); + } + + cout << "found: " << result.size() << " GPU device(s)\n"; + + return result; +} + +std::vector create_sycl_queues(const std::string& device_type, + const std::vector& ranks) { + std::vector devices; + + try { + if ((device_type.compare("gpu") == 0) && has_gpu()) { + /* special handling to cover multi-tile case */ + devices = create_sycl_gpu_devices(); } - else if (strcmp(argv[1], "gpu") == 0) { - if (has_gpu()) { - selector.reset(new gpu_selector()); + else { + unique_ptr selector; + + if (device_type.compare("cpu") == 0) { + selector.reset(new cpu_selector()); } - else if (has_accelerator()) { + else if (device_type.compare("gpu") == 0) { + if (has_accelerator()) { + selector.reset(new host_selector()); + cout + << "Accelerator is the first in device list, but unavailable for multiprocessing " + << "host_selector has been created instead of default_selector.\n"; + } + else { + selector.reset(new default_selector()); + cout + << "GPU is unavailable, default_selector has been created instead of gpu_selector.\n"; + } + } + else if (device_type.compare("host") == 0) { selector.reset(new host_selector()); - cout - << "Accelerator is the first in device list, but unavailable for multiprocessing, host_selector has been created instead of default_selector.\n"; + } + else if (device_type.compare("default") == 0) { + if (!has_accelerator()) { + selector.reset(new default_selector()); + } + else { + selector.reset(new host_selector()); + cout + << "Accelerator is the first in device list, but unavailable for multiprocessing " + << " 
host_selector has been created instead of default_selector.\n"; + } } else { - selector.reset(new default_selector()); - cout - << "GPU is unavailable, default_selector has been created instead of gpu_selector.\n"; + throw std::runtime_error("Please provide device type: cpu | gpu | host | default"); } + devices.push_back(sycl::device(*selector)); } - else if (strcmp(argv[1], "host") == 0) { - selector.reset(new host_selector()); - } - else if (strcmp(argv[1], "default") == 0) { - if (!has_accelerator()) { - selector.reset(new default_selector()); + } + catch (...) { + throw std::runtime_error("No devices of requested type available"); + } + + if (devices.empty()) { + throw std::runtime_error("No devices of requested type available"); + } + + std::vector rank_devices; + + for (size_t idx = 0; idx < ranks.size(); idx++) { + rank_devices.push_back(devices[ranks[idx] % devices.size()]); + } + + if (rank_devices.empty()) { + throw std::runtime_error("No devices of requested type available for specified ranks"); + } + + sycl::context ctx; + + try { + ctx = sycl::context(rank_devices); + } + catch (sycl::runtime_error&) { + size_t preferred_idx = (ranks.back() / ranks.size()) % devices.size(); + cout << "Can not create context from all rank devices of type: " << device_type + << ", create context from single device, idx " << preferred_idx << "\n"; + ctx = sycl::context(devices[preferred_idx]); + } + + auto exception_handler = [&](exception_list elist) { + for (exception_ptr const& e : elist) { + try { + rethrow_exception(e); } - else { - selector.reset(new host_selector()); - cout - << "Accelerator is the first in device list, but unavailable for multiprocessing, host_selector has been created instead of default_selector.\n"; + catch (std::exception const& e) { + cout << "failure\n"; } } - else { - cerr << "Please provide device type: cpu | gpu | host | default\n"; + }; + + auto ctx_devices = ctx.get_devices(); + + if (ctx_devices.empty()) { + throw std::runtime_error("No 
devices of requested type available in context"); + } + + std::vector queues; + + cout << "Created context from devices of type: " << device_type << "\n"; + cout << "Devices [" << ctx_devices.size() << "]:\n"; + + for (size_t idx = 0; idx < ctx_devices.size(); idx++) { + cout << "[" << idx << "]: [" << ctx_devices[idx].get_info() << "]\n"; + queues.push_back(sycl::queue(ctx_devices[idx], exception_handler)); + } + + return queues; +} + +inline bool create_sycl_queue(int argc, char* argv[], int rank, queue& q) { + if (argc >= 2) { + try { + std::vector ranks = { rank }; + q = create_sycl_queues(argv[1], ranks)[0]; + return true; + } + catch (std::exception& e) { + cerr << e.what() << "\n"; return false; } - q = queue(*selector, exception_handler); - cout << "Requested device type: " << argv[1] << "\nRunning on " - << q.get_device().get_info() << "\n"; } else { cerr << "Please provide device type: cpu | gpu | host | default\n"; return false; } - return true; } bool handle_exception(queue& q) { @@ -166,18 +369,27 @@ usm::alloc usm_alloc_type_from_string(const string& str) { return it->second; } -template -struct buf_allocator { +std::pair take_usm_type(const int argc, char* str_type) { + std::map map_usm_type; + auto usm_alloc_type = usm::alloc::shared; + auto str_usm_alloc_type = "shared"; + if (argc > 1) { + str_usm_alloc_type = str_type; + usm_alloc_type = usm_alloc_type_from_string(str_usm_alloc_type); + } + + return std::make_pair(usm_alloc_type, str_usm_alloc_type); +} +template +struct buf_allocator { const size_t alignment = 64; - buf_allocator(queue& q) - : q(q) - {} + buf_allocator(queue& q) : q(q) {} ~buf_allocator() { for (auto& ptr : memory_storage) { - cl::sycl::free(ptr, q); + sycl::free(ptr, q); } } @@ -186,7 +398,7 @@ struct buf_allocator { if (alloc_type == usm::alloc::host) ptr = aligned_alloc_host(alignment, count, q); else if (alloc_type == usm::alloc::device) - ptr = aligned_alloc_device(alignment, count, q); + ptr = 
aligned_alloc_device(alignment, count, q); else if (alloc_type == usm::alloc::shared) ptr = aligned_alloc_shared(alignment, count, q); else @@ -195,10 +407,16 @@ struct buf_allocator { auto it = memory_storage.find(ptr); if (it != memory_storage.end()) { throw std::runtime_error(string(__PRETTY_FUNCTION__) + - " - allocator already owns this pointer"); + " - allocator already owns this pointer"); } memory_storage.insert(ptr); + auto pointer_type = sycl::get_pointer_type(ptr, q.get_context()); + if (pointer_type != alloc_type) + throw std::runtime_error( + string(__PRETTY_FUNCTION__) + "pointer_type " + std::to_string((int)pointer_type) + + " doesn't match with requested " + std::to_string((int)alloc_type)); + return ptr; } @@ -206,7 +424,7 @@ struct buf_allocator { auto it = memory_storage.find(ptr); if (it == memory_storage.end()) { throw std::runtime_error(string(__PRETTY_FUNCTION__) + - " - allocator doesn't own this pointer"); + " - allocator doesn't own this pointer"); } free(ptr, q); memory_storage.erase(it); diff --git a/examples/sycl/CMakeLists.txt b/examples/sycl/CMakeLists.txt index f2653b28a..412234d68 100644 --- a/examples/sycl/CMakeLists.txt +++ b/examples/sycl/CMakeLists.txt @@ -28,5 +28,5 @@ foreach(src ${sources}) target_link_libraries(${executable} PRIVATE m) target_link_libraries(${executable} PUBLIC mpi) target_link_libraries(${executable} PRIVATE ${COMPUTE_RUNTIME_TARGET_NAME}) - install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/sycl) + install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/sycl OPTIONAL) endforeach() diff --git a/examples/sycl/sycl_allgatherv_custom_usm_test.cpp b/examples/sycl/sycl_allgatherv_custom_usm_test.cpp index af707eb0a..9380a064f 100644 --- a/examples/sycl/sycl_allgatherv_custom_usm_test.cpp +++ b/examples/sycl/sycl_allgatherv_custom_usm_test.cpp @@ -24,7 +24,6 @@ struct custom_data_type { } __attribute__((packed)); int main(int argc, char *argv[]) { - const size_t count = 10 
* 1024 * 1024; int i = 0; @@ -33,8 +32,14 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } @@ -50,10 +55,6 @@ int main(int argc, char *argv[]) { } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -89,17 +90,17 @@ int main(int argc, char *argv[]) { auto e = q.submit([&](auto &h) { accessor expected_buf_acc(expected_buf, h, write_only); h.parallel_for(send_count, [=](auto id) { - static_cast(send_buf)[id] = rank + 1; - for (int i = 0; i < size; i++) { - static_cast(recv_buf)[id] = -1; - expected_buf_acc[i * send_count + id] = i + 1; - } - }); + static_cast(send_buf)[id] = rank + 1; + for (int i = 0; i < size; i++) { + static_cast(recv_buf)[id] = -1; + expected_buf_acc[i * send_count + id] = i + 1; + } + }); }); /* create dependency vector */ vector events; - events.push_back(ccl::create_event(e)); + // events.push_back(ccl::create_event(e)); if (!handle_exception(q)) return -1; @@ -122,10 +123,10 @@ int main(int argc, char *argv[]) { accessor expected_buf_acc(expected_buf, h, read_only); accessor check_buf_acc(check_buf, h, write_only); h.parallel_for(size * send_count, [=](auto id) { - if (static_cast(recv_buf)[id] != expected_buf_acc[id]) { - check_buf_acc[id] = -1; - } - }); + if (static_cast(recv_buf)[id] != expected_buf_acc[id]) { + check_buf_acc[id] = -1; + } + }); }); if (!handle_exception(q)) @@ -145,7 +146,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/examples/sycl/sycl_allgatherv_inplace_test.cpp b/examples/sycl/sycl_allgatherv_inplace_test.cpp index 18afb02b9..6fd7d7bc2 100644 --- 
a/examples/sycl/sycl_allgatherv_inplace_test.cpp +++ b/examples/sycl/sycl_allgatherv_inplace_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; int main(int argc, char *argv[]) { - const size_t count = 10 * 1024 * 1024; int i = 0; @@ -31,16 +30,18 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -103,8 +104,8 @@ int main(int argc, char *argv[]) { accessor send_buf_acc(send_buf, h, read_only); accessor recv_buf_acc(recv_buf, h, write_only); h.parallel_for(send_buf_count, [=](auto id) { - recv_buf_acc[rbuf_idx + id] = send_buf_acc[id] + 1; - }); + recv_buf_acc[rbuf_idx + id] = send_buf_acc[id] + 1; + }); }); if (!handle_exception(q)) @@ -119,10 +120,10 @@ int main(int argc, char *argv[]) { accessor recv_buf_acc(recv_buf, h, write_only); accessor expected_buf_acc(expected_buf, h, read_only); h.parallel_for(recv_buf_count, [=](auto id) { - if (recv_buf_acc[id] != expected_buf_acc[id]) { - recv_buf_acc[id] = -1; - } - }); + if (recv_buf_acc[id] != expected_buf_acc[id]) { + recv_buf_acc[id] = -1; + } + }); }); if (!handle_exception(q)) @@ -142,7 +143,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/examples/sycl/sycl_allgatherv_test.cpp b/examples/sycl/sycl_allgatherv_test.cpp index 7854a3a9d..7ac3a48b0 100644 --- a/examples/sycl/sycl_allgatherv_test.cpp +++ b/examples/sycl/sycl_allgatherv_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; int main(int argc, char *argv[]) { - const size_t count = 10 * 1024 * 1024; int i = 0; @@ -29,16 +28,18 @@ int 
main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -109,10 +110,10 @@ int main(int argc, char *argv[]) { accessor recv_buf_acc(recv_buf, h, write_only); accessor expected_buf_acc(expected_buf, h, read_only); h.parallel_for(size * count, [=](auto id) { - if (recv_buf_acc[id] != expected_buf_acc[id]) { - recv_buf_acc[id] = -1; - } - }); + if (recv_buf_acc[id] != expected_buf_acc[id]) { + recv_buf_acc[id] = -1; + } + }); }); if (!handle_exception(q)) @@ -132,7 +133,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/examples/sycl/sycl_allgatherv_usm_test.cpp b/examples/sycl/sycl_allgatherv_usm_test.cpp index 0160deab6..a6013485a 100644 --- a/examples/sycl/sycl_allgatherv_usm_test.cpp +++ b/examples/sycl/sycl_allgatherv_usm_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; int main(int argc, char *argv[]) { - const size_t count = 10 * 1024 * 1024; int i = 0; @@ -28,8 +27,14 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } @@ -45,10 +50,6 @@ int main(int argc, char *argv[]) { } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -91,7 +92,7 @@ int main(int argc, char *argv[]) { /* create 
dependency vector */ vector events; - events.push_back(ccl::create_event(e)); + // events.push_back(ccl::create_event(e)); if (!handle_exception(q)) return -1; @@ -105,10 +106,10 @@ int main(int argc, char *argv[]) { accessor expected_buf_acc(expected_buf, h, read_only); accessor check_buf_acc(check_buf, h, write_only); h.parallel_for(size * count, [=](auto id) { - if (recv_buf[id] != expected_buf_acc[id]) { - check_buf_acc[id] = -1; - } - }); + if (recv_buf[id] != expected_buf_acc[id]) { + check_buf_acc[id] = -1; + } + }); }); if (!handle_exception(q)) @@ -128,7 +129,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/examples/sycl/sycl_allreduce_inplace_usm_test.cpp b/examples/sycl/sycl_allreduce_inplace_usm_test.cpp index 086eefc20..4c0605ba2 100644 --- a/examples/sycl/sycl_allreduce_inplace_usm_test.cpp +++ b/examples/sycl/sycl_allreduce_inplace_usm_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; int main(int argc, char *argv[]) { - const size_t count = 10 * 1024 * 1024; int i = 0; @@ -28,8 +27,14 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } @@ -45,10 +50,6 @@ int main(int argc, char *argv[]) { } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -113,7 +114,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/examples/sycl/sycl_allreduce_test.cpp b/examples/sycl/sycl_allreduce_test.cpp index c42ab5043..6200b3c33 100644 --- a/examples/sycl/sycl_allreduce_test.cpp +++ b/examples/sycl/sycl_allreduce_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; int 
main(int argc, char *argv[]) { - const size_t count = 10 * 1024 * 1024; int i = 0; @@ -28,16 +27,18 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -113,7 +114,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/examples/sycl/sycl_allreduce_usm_test.cpp b/examples/sycl/sycl_allreduce_usm_test.cpp index fe87a8ff1..e2fceb44a 100644 --- a/examples/sycl/sycl_allreduce_usm_test.cpp +++ b/examples/sycl/sycl_allreduce_usm_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; int main(int argc, char *argv[]) { - const size_t count = 10 * 1024 * 1024; int i = 0; @@ -28,8 +27,14 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } @@ -45,10 +50,6 @@ int main(int argc, char *argv[]) { } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -115,7 +116,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/examples/sycl/sycl_alltoall_test.cpp b/examples/sycl/sycl_alltoall_test.cpp index 3e8aede51..23e20629a 100644 --- a/examples/sycl/sycl_alltoall_test.cpp +++ b/examples/sycl/sycl_alltoall_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; 
int main(int argc, char *argv[]) { - const size_t count = 10 * 1024 * 1024; int i = 0; @@ -29,16 +28,18 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -79,10 +80,9 @@ int main(int argc, char *argv[]) { /* open send_buf and modify it on the device side */ q.submit([&](auto &h) { accessor send_buf_acc(send_buf, h, write_only); - h.parallel_for(count * size, - [=](auto id) { - send_buf_acc[id] += 1; - }); + h.parallel_for(count * size, [=](auto id) { + send_buf_acc[id] += 1; + }); }); if (!handle_exception(q)) @@ -118,7 +118,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/examples/sycl/sycl_alltoall_usm_test.cpp b/examples/sycl/sycl_alltoall_usm_test.cpp index 3a5c8c595..8fa744a97 100644 --- a/examples/sycl/sycl_alltoall_usm_test.cpp +++ b/examples/sycl/sycl_alltoall_usm_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; int main(int argc, char *argv[]) { - const size_t count = 10 * 1024 * 1024; int i = 0; @@ -28,8 +27,14 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } @@ -45,10 +50,6 @@ int main(int argc, char *argv[]) { } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) 
{ @@ -75,11 +76,10 @@ int main(int argc, char *argv[]) { /* open buffers and modify them on the device side */ q.submit([&](auto &h) { - h.parallel_for(count * size, - [=](auto id) { - send_buf[id] = id / count + 1; - recv_buf[id] = -1; - }); + h.parallel_for(count * size, [=](auto id) { + send_buf[id] = id / count + 1; + recv_buf[id] = -1; + }); }); if (!handle_exception(q)) @@ -116,7 +116,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/examples/sycl/sycl_alltoallv_test.cpp b/examples/sycl/sycl_alltoallv_test.cpp index cc9cf5181..fd9bd7810 100644 --- a/examples/sycl/sycl_alltoallv_test.cpp +++ b/examples/sycl/sycl_alltoallv_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; int main(int argc, char *argv[]) { - const size_t count = 10 * 1024 * 1024; int i = 0; @@ -29,16 +28,18 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -82,10 +83,9 @@ int main(int argc, char *argv[]) { /* open send_buf and modify it on the device side */ q.submit([&](auto &h) { accessor send_buf_acc(send_buf, h, write_only); - h.parallel_for(count * size, - [=](auto id) { - send_buf_acc[id] += 1; - }); + h.parallel_for(count * size, [=](auto id) { + send_buf_acc[id] += 1; + }); }); if (!handle_exception(q)) @@ -97,12 +97,11 @@ int main(int argc, char *argv[]) { /* open recv_buf and check its correctness on the device side */ q.submit([&](auto &h) { accessor recv_buf_acc(recv_buf, h, write_only); - h.parallel_for(count * size, - [=](auto id) { - if (recv_buf_acc[id] != rank + 1) { - 
recv_buf_acc[id] = -1; - } - }); + h.parallel_for(count * size, [=](auto id) { + if (recv_buf_acc[id] != rank + 1) { + recv_buf_acc[id] = -1; + } + }); }); if (!handle_exception(q)) @@ -122,7 +121,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/examples/sycl/sycl_alltoallv_usm_test.cpp b/examples/sycl/sycl_alltoallv_usm_test.cpp index f9f0519c4..5f23ad973 100644 --- a/examples/sycl/sycl_alltoallv_usm_test.cpp +++ b/examples/sycl/sycl_alltoallv_usm_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; int main(int argc, char *argv[]) { - const size_t count = 10 * 1024 * 1024; int i = 0; @@ -28,8 +27,14 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } @@ -45,10 +50,6 @@ int main(int argc, char *argv[]) { } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -78,11 +79,10 @@ int main(int argc, char *argv[]) { /* open buffers and modify them on the device side */ q.submit([&](auto &h) { - h.parallel_for(count * size, - [=](auto id) { - send_buf[id] = id / count + 1; - recv_buf[id] = -1; - }); + h.parallel_for(count * size, [=](auto id) { + send_buf[id] = id / count + 1; + recv_buf[id] = -1; + }); }); if (!handle_exception(q)) @@ -119,7 +119,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/examples/sycl/sycl_broadcast_test.cpp b/examples/sycl/sycl_broadcast_test.cpp index 8594731a8..1976afdd5 100644 --- a/examples/sycl/sycl_broadcast_test.cpp +++ b/examples/sycl/sycl_broadcast_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; int main(int argc, char *argv[]) { - 
const size_t count = 10 * 1024 * 1024; const size_t root_rank = 0; @@ -29,16 +28,18 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -67,7 +68,7 @@ int main(int argc, char *argv[]) { host_accessor send_buf_acc(buf, write_only); for (i = 0; i < count; i++) { if (rank == root_rank) - send_buf_acc[i] = rank; + send_buf_acc[i] = rank + 10; else send_buf_acc[i] = 0; } @@ -91,7 +92,7 @@ int main(int argc, char *argv[]) { q.submit([&](auto &h) { accessor recv_buf_acc(buf, h, write_only); h.parallel_for(count, [=](auto id) { - if (recv_buf_acc[id] != root_rank + 1) { + if (recv_buf_acc[id] != root_rank + 11) { recv_buf_acc[id] = -1; } }); @@ -101,20 +102,16 @@ int main(int argc, char *argv[]) { return -1; /* print out the result of the test on the host side */ - if (rank == root_rank) { - host_accessor recv_buf_acc(buf, read_only); - for (i = 0; i < count; i++) { - if (recv_buf_acc[i] == -1) { - cout << "FAILED\n"; - break; - } - } - if (i == count) { - cout << "PASSED\n"; + host_accessor recv_buf_acc(buf, read_only); + for (i = 0; i < count; i++) { + if (recv_buf_acc[i] == -1) { + cout << "FAILED\n"; + break; } } - - MPI_Finalize(); + if (i == count) { + cout << "PASSED\n"; + } return 0; } diff --git a/examples/sycl/sycl_broadcast_usm_test.cpp b/examples/sycl/sycl_broadcast_usm_test.cpp index cf64997d4..78b95af82 100644 --- a/examples/sycl/sycl_broadcast_usm_test.cpp +++ b/examples/sycl/sycl_broadcast_usm_test.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace sycl; int main(int argc, char *argv[]) { - const size_t count 
= 10 * 1024 * 1024; const size_t root_rank = 0; @@ -29,8 +28,14 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } @@ -46,10 +51,6 @@ int main(int argc, char *argv[]) { } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -76,8 +77,11 @@ int main(int argc, char *argv[]) { /* open buffers and modify them on the device side */ q.submit([&](auto &h) { h.parallel_for(count, [=](auto id) { - if (id == root_rank) { - buf[id] = root_rank; + if (rank == root_rank) { + buf[id] = root_rank + 10; + } + else { + buf[id] = 0; } buf[id] += 1; }); @@ -94,7 +98,7 @@ int main(int argc, char *argv[]) { q.submit([&](auto &h) { accessor check_buf_acc(check_buf, h, write_only); h.parallel_for(count, [=](auto id) { - if (buf[id] != root_rank + 1) { + if (buf[id] != root_rank + 11) { check_buf_acc[id] = -1; } }); @@ -104,20 +108,16 @@ int main(int argc, char *argv[]) { return -1; /* print out the result of the test on the host side */ - if (rank == root_rank) { - host_accessor check_buf_acc(check_buf, read_only); - for (i = 0; i < count; i++) { - if (check_buf_acc[i] == -1) { - cout << "FAILED\n"; - break; - } - } - if (i == count) { - cout << "PASSED\n"; + host_accessor check_buf_acc(check_buf, read_only); + for (i = 0; i < count; i++) { + if (check_buf_acc[i] == -1) { + cout << "FAILED\n"; + break; } } - - MPI_Finalize(); + if (i == count) { + cout << "PASSED\n"; + } return 0; } diff --git a/examples/sycl/sycl_reduce_test.cpp b/examples/sycl/sycl_reduce_test.cpp index ae2739c5c..8d3230a2a 100644 --- a/examples/sycl/sycl_reduce_test.cpp +++ b/examples/sycl/sycl_reduce_test.cpp @@ -19,7 +19,6 @@ 
using namespace std; using namespace sycl; int main(int argc, char *argv[]) { - const size_t count = 10 * 1024 * 1024; const size_t root_rank = 0; @@ -29,16 +28,18 @@ int main(int argc, char *argv[]) { ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + queue q; - if (!create_sycl_queue(argc, argv, q)) { + if (!create_sycl_queue(argc, argv, rank, q)) { return -1; } /* create kvs */ - MPI_Init(NULL, NULL); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - ccl::shared_ptr_class kvs; ccl::kvs::address_type main_addr; if (rank == 0) { @@ -124,7 +125,5 @@ int main(int argc, char *argv[]) { } } - MPI_Finalize(); - return 0; } diff --git a/include/oneapi/ccl.hpp b/include/oneapi/ccl.hpp index 8a0d3356b..cb60ff453 100644 --- a/include/oneapi/ccl.hpp +++ b/include/oneapi/ccl.hpp @@ -15,9 +15,9 @@ */ #pragma once -#include "oneapi/ccl/ccl_environment.hpp" - -#include "oneapi/ccl/ccl_api_functions.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/environment.hpp" +#include "oneapi/ccl/api_functions.hpp" namespace ccl {} namespace oneapi { diff --git a/include/oneapi/ccl/ccl_aliases.hpp b/include/oneapi/ccl/aliases.hpp similarity index 96% rename from include/oneapi/ccl/ccl_aliases.hpp rename to include/oneapi/ccl/aliases.hpp index 9cb28d212..74031d7d8 100644 --- a/include/oneapi/ccl/ccl_aliases.hpp +++ b/include/oneapi/ccl/aliases.hpp @@ -26,6 +26,8 @@ #include #include +#include "oneapi/ccl/string.hpp" + namespace ccl { template > using vector_class = std::vector; @@ -33,7 +35,7 @@ using vector_class = std::vector; template using array_class = std::array; -using string_class = std::string; +using string_class = ccl::string; template using function_class = std::function; diff --git a/include/oneapi/ccl/ccl_api_functions.hpp b/include/oneapi/ccl/api_functions.hpp similarity index 55% rename from include/oneapi/ccl/ccl_api_functions.hpp rename 
to include/oneapi/ccl/api_functions.hpp index f2bedb1dc..6a1c8c961 100644 --- a/include/oneapi/ccl/ccl_api_functions.hpp +++ b/include/oneapi/ccl/api_functions.hpp @@ -13,1191 +13,1201 @@ See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - -#ifndef CCL_PRODUCT_FULL -#error "Do not include this file directly. Please include 'ccl.hpp'" -#endif - -namespace ccl { - -/******************** INIT ********************/ - -/** - * Initializes the library. Optional for invocation. - */ -void init(); - -/** - * Retrieves the library version - */ -library_version get_library_version(); - - -/******************** DATATYPE ********************/ - -/** - * Creates a datatype attribute object, which may used to register custom datatype - * @return an attribute object - */ -template -datatype_attr CCL_API create_datatype_attr(attr_value_pair_t&&... avps) { - return environment::instance().create_datatype_attr(std::forward(avps)...); -} - -/** - * Registers custom datatype to be used in communication operations - * @param attr datatype attributes - * @return datatype handle - */ -datatype register_datatype(const datatype_attr& attr); - -/** - * Deregisters custom datatype - * @param dtype custom datatype handle - */ -void deregister_datatype(datatype dtype); - -/** - * Retrieves a datatype size in bytes - * @param dtype datatype handle - * @return datatype size - */ -size_t get_datatype_size(datatype dtype); - - -/******************** KVS ********************/ - -/** - * Creates a main key-value store. - * It's address should be distributed using out of band communication mechanism - * and be used to create key-value stores on other ranks. 
- * @return kvs object - */ -shared_ptr_class create_main_kvs(); - -/** - * Creates a new key-value store from main kvs address - * @param addr address of main kvs - * @return kvs object - */ -shared_ptr_class create_kvs(const kvs::address_type& addr); - - -/******************** DEVICE ********************/ - -/** - * Creates a new device from @native_device_type - * @param native_device the existing handle of device - * @return device object - */ -device create_device(); - -template ()>::type> -device create_device(native_device_type&& native_device) { - return environment::instance().create_device(std::forward(native_device)); -} - -template -device create_device_from_attr(typename unified_device_type::ccl_native_t dev, - attr_value_pair_t&&... avps) { - return environment::instance().create_device_from_attr( - dev, std::forward(avps)...); -} - - -/******************** CONTEXT ********************/ - -/** - * Creates a new context from @native_device_contex_type - * @param native_device_context the existing handle of context - * @return context object - */ -context create_context(); - -template ()>::type> -context create_context(native_device_context_type&& native_device_context) { - return environment::instance().create_context(std::forward(native_device_context)); -} - -template -context create_context_from_attr(typename unified_device_context_type::ccl_native_t ctx, - attr_value_pair_t&&... 
avps) { - return environment::instance().create_context_from_attr( - ctx, std::forward(avps)...); -} - -/******************** EVENT ********************/ - -/** - * Creates a new event from @native_event_type - * @param native_event the existing handle of event - * @return event object - */ -template ()>::type> -event create_event(event_type& native_event) { - return environment::instance().create_event(native_event); -} - - -/******************** STREAM ********************/ - -/** - * Creates a new stream from @native_stream_type - * @param native_stream the existing handle of stream - * @return stream object - */ -stream create_stream(); - -template ()>::type> -stream create_stream(native_stream_type& native_stream) { - return environment::instance().create_stream(native_stream); -} - -template ()>::type> -stream create_stream(native_stream_type& native_stream, native_context_type& native_ctx) { - return environment::instance().create_stream(native_stream, native_ctx); -} - -template -stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, - attr_value_pair_t&&... avps) { - return environment::instance().create_stream_from_attr( - device, std::forward(avps)...); -} - -template -stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, - typename unified_device_context_type::ccl_native_t context, - attr_value_pair_t&&... avps) { - return environment::instance().create_stream_from_attr( - device, context, std::forward(avps)...); -} - - -/******************** COMMUNICATOR ********************/ - -template -comm_split_attr create_comm_split_attr(attr_value_pair_t&&... avps) { - return environment::instance().create_comm_split_attr( - std::forward(avps)...); -} - -namespace preview { - -/** - * Splits device communicators according to attributes. 
- * @param attrs split attributes for local communicators - * @return vector of device communicators - */ -vector_class split_device_communicators( - const vector_class>& attrs); - - -/** - * Creates a new communicator with externally provided size, rank and kvs. - * Implementation is platform specific and non portable. - * @return communicator - */ -communicator create_communicator(); - -/** - * Creates a new communicator with user supplied size and kvs. - * Rank will be assigned automatically. - * @param size user-supplied total number of ranks - * @param kvs key-value store for ranks wire-up - * @return communicator - */ -communicator create_communicator(size_t size, shared_ptr_class kvs); - -} // namespace preview - - -/** - * Creates a new communicator with user supplied size, rank and kvs. - * @param size user-supplied total number of ranks - * @param rank user-supplied rank - * @param kvs key-value store for ranks wire-up - * @return communicator - */ -communicator create_communicator(size_t size, - size_t rank, - shared_ptr_class kvs); - -/** - * Creates a new communicators with user supplied size, locao devices and kvs. - * Ranks will be assigned automatically. - * @param size user-supplied total number of ranks - * @param local_devices user-supplied device objects for local ranks - * @param context context containing the devices - * @param kvs key-value store for ranks wire-up - * @return vector of communicators - */ -template -vector_class create_communicators( - size_t size, - const vector_class& local_devices, - ContextType& context, - shared_ptr_class kvs) { - return environment::instance().create_communicators( - size, local_devices, context, kvs); -} - -/** - * Creates a new communicators with user supplied size, ranks, local device-rank mapping and kvs. 
- * @param size user-supplied total number of ranks - * @param local_rank_device_map user-supplied mapping of local ranks on devices - * @param context context containing the devices - * @param kvs key-value store for ranks wire-up - * @return vector of communicators - */ -template -vector_class create_communicators( - size_t size, - const vector_class>& local_rank_device_map, - ContextType& context, - shared_ptr_class kvs) { - return environment::instance().create_communicators( - size, local_rank_device_map, context, kvs); -} - -template -vector_class create_communicators( - size_t size, - const map_class& local_rank_device_map, - ContextType& context, - shared_ptr_class kvs) { - return environment::instance().create_communicators( - size, local_rank_device_map, context, kvs); -} - -template -communicator create_communicator( - size_t size, - rank_t rank, - DeviceType& device, - ContextType& context, - shared_ptr_class kvs) { - - auto comms = environment::instance().create_communicators( - size, ccl::vector_class>{{rank,device}}, context, kvs); - - if (comms.size() != 1) - throw ccl::exception("unexpected comm vector size"); - - return std::move(comms[0]); -} - - -/******************** OPERATION ********************/ - -/** - * Creates an operation attribute object, which may used to customize communication operation - * @return an attribute object - */ -template -coll_attribute_type CCL_API create_operation_attr(attr_value_pair_t&&... avps) { - return environment::instance().create_operation_attr( - std::forward(avps)...); -} - -/** - * Allgatherv is a collective communication operation that collects data - * from all the ranks within a communicator into a single buffer. - * Different ranks may contribute segments of different sizes. - * The resulting data in the output buffer must be the same for each rank. 
- */ - -/** - * @param send_buf the buffer with @c send_count elements of @c dtype that stores local data to be gathered - * @param send_count the number of elements of type @c dtype in @c send_buf - * @param recv_buf [out] the buffer to store gathered result, should be large enough to hold values from all ranks - * @param recv_bufs [out] array of buffers to store gathered result, one buffer per each rank - * @param recv_counts array with the number of elements of type @c dtype to be received from each rank - * @param dtype the datatype of elements in @c send_buf and @c recv_buf - * @param comm the communicator for which the operation will be performed - * @param stream an optional stream associated with the operation - * @param attr optional attributes to customize operation - * @param deps an optional vector of the events that the operation should depend on - * @return @ref ccl::event an object to track the progress of the operation - */ -event allgatherv(const void* send_buf, - size_t send_count, - void* recv_buf, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const stream& stream, - const allgatherv_attr& attr = default_allgatherv_attr, - const vector_class& deps = {}); - -event allgatherv(const void* send_buf, - size_t send_count, - void* recv_buf, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const allgatherv_attr& attr = default_allgatherv_attr, - const vector_class& deps = {}); - -event allgatherv(const void* send_buf, - size_t send_count, - const vector_class& recv_bufs, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const stream& stream, - const allgatherv_attr& attr = default_allgatherv_attr, - const vector_class& deps = {}); - -event allgatherv(const void* send_buf, - size_t send_count, - const vector_class& recv_bufs, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const allgatherv_attr& attr = 
default_allgatherv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event allgatherv(const BufferType* send_buf, - size_t send_count, - BufferType* recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const stream& stream, - const allgatherv_attr& attr = default_allgatherv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event allgatherv(const BufferType* send_buf, - size_t send_count, - BufferType* recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const allgatherv_attr& attr = default_allgatherv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event allgatherv(const BufferType* send_buf, - size_t send_count, - vector_class& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const stream& stream, - const allgatherv_attr& attr = default_allgatherv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event allgatherv(const BufferType* send_buf, - size_t send_count, - vector_class& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const allgatherv_attr& attr = default_allgatherv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event allgatherv(const BufferObjectType& send_buf, - size_t send_count, - BufferObjectType& recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const stream& stream, - const allgatherv_attr& attr = default_allgatherv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event allgatherv(const BufferObjectType& send_buf, - size_t send_count, - BufferObjectType& recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const allgatherv_attr& attr = default_allgatherv_attr, - const vector_class& deps = {}); - -/* Type safety version 
*/ -template (), event>::type> -event allgatherv(const BufferObjectType& send_buf, - size_t send_count, - vector_class>& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const stream& stream, - const allgatherv_attr& attr = default_allgatherv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event allgatherv(const BufferObjectType& send_buf, - size_t send_count, - vector_class>& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const allgatherv_attr& attr = default_allgatherv_attr, - const vector_class& deps = {}); - -/** - * Allreduce is a collective communication operation that performs the global reduction operation - * on values from all ranks of communicator and distributes the result back to all ranks. - */ - -/** - * @param send_buf the buffer with @c count elements of @c dtype that stores local data to be reduced - * @param recv_buf [out] the buffer to store reduced result, must have the same dimension as @c send_buf - * @param count the number of elements of type @c dtype in @c send_buf and @c recv_buf - * @param dtype the datatype of elements in @c send_buf and @c recv_buf - * @param rtype the type of the reduction operation to be applied - * @param comm the communicator for which the operation will be performed - * @param stream an optional stream associated with the operation - * @param attr optional attributes to customize operation - * @param deps an optional vector of the events that the operation should depend on - * @return @ref ccl::event an object to track the progress of the operation - */ -event allreduce(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - reduction rtype, - const communicator& comm, - const stream& stream, - const allreduce_attr& attr = default_allreduce_attr, - const vector_class& deps = {}); - -event allreduce(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - reduction rtype, 
- const communicator& comm, - const allreduce_attr& attr = default_allreduce_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event allreduce(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - reduction rtype, - const communicator& comm, - const stream& stream, - const allreduce_attr& attr = default_allreduce_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event allreduce(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - reduction rtype, - const communicator& comm, - const allreduce_attr& attr = default_allreduce_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event allreduce(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - reduction rtype, - const communicator& comm, - const stream& stream, - const allreduce_attr& attr = default_allreduce_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event allreduce(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - reduction rtype, - const communicator& comm, - const allreduce_attr& attr = default_allreduce_attr, - const vector_class& deps = {}); - -/** - * Alltoall is a collective communication operation in which each rank - * sends distinct equal-sized blocks of data to each rank. - * The j-th block of @c send_buf sent from the i-th rank is received by the j-th rank - * and is placed in the i-th block of @c recvbuf. - */ - -/** - * @param send_buf the buffer with @c count elements of @c dtype that stores local data to be sent - * @param recv_buf [out] the buffer to store received result, should be large enough - * to hold values from all ranks, i.e. 
at least @c comm_size * @c count - * @param send_bufs array of buffers with local data to be sent, one buffer per each rank - * @param recv_bufs [out] array of buffers to store received result, one buffer per each rank - * @param count the number of elements of type @c dtype to be send to or to received from each rank - * @param dtype the datatype of elements in @c send_buf and @c recv_buf - * @param comm the communicator for which the operation will be performed - * @param stream an optional stream associated with the operation - * @param attr optional attributes to customize operation - * @param deps an optional vector of the events that the operation should depend on - * @return @ref ccl::event an object to track the progress of the operation - */ -event alltoall(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - const communicator& comm, - const stream& stream, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -event alltoall(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - const communicator& comm, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -event alltoall(const vector_class& send_buf, - const vector_class& recv_buf, - size_t count, - datatype dtype, - const communicator& comm, - const stream& stream, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -event alltoall(const vector_class& send_buf, - const vector_class& recv_buf, - size_t count, - datatype dtype, - const communicator& comm, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoall(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - const communicator& comm, - const stream& stream, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template 
(), event>::type> -event alltoall(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - const communicator& comm, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoall(const vector_class& send_buf, - const vector_class& recv_buf, - size_t count, - const communicator& comm, - const stream& stream, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoall(const vector_class& send_buf, - const vector_class& recv_buf, - size_t count, - const communicator& comm, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoall(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - const communicator& comm, - const stream& stream, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoall(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - const communicator& comm, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoall(const vector_class>& send_buf, - const vector_class>& recv_buf, - size_t count, - const communicator& comm, - const stream& stream, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoall(const vector_class>& send_buf, - const vector_class>& recv_buf, - size_t count, - const communicator& comm, - const alltoall_attr& attr = default_alltoall_attr, - const vector_class& deps = {}); - -/** - * Alltoallv is a collective communication operation in which each rank - * sends distinct blocks 
of data to each rank. Block sizes may differ. - * The j-th block of @c send_buf sent from the i-th rank is received by the j-th rank - * and is placed in the i-th block of @c recvbuf. - */ - -/** - * @param send_buf the buffer with elements of @c dtype that stores local blocks to be sent to each rank - * @param send_bufs array of buffers to store send blocks, one buffer per each rank - * @param recv_buf [out] the buffer to store received result, should be large enough to hold blocks from all ranks - * @param recv_bufs [out] array of buffers to store receive blocks, one buffer per each rank - * @param send_counts array with the number of elements of type @c dtype in send blocks for each rank - * @param recv_counts array with the number of elements of type @c dtype in receive blocks from each rank - * @param dtype the datatype of elements in @c send_buf and @c recv_buf - * @param comm the communicator for which the operation will be performed - * @param stream an optional stream associated with the operation - * @param attr optional attributes to customize operation - * @param deps an optional vector of the events that the operation should depend on - * @return @ref ccl::event an object to track the progress of the operation - */ -event alltoallv(const void* send_buf, - const vector_class& send_counts, - void* recv_buf, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const stream& stream, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -event alltoallv(const void* send_buf, - const vector_class& send_counts, - void* recv_buf, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -event alltoallv(const vector_class& send_bufs, - const vector_class& send_counts, - const vector_class& recv_bufs, - const vector_class& recv_counts, - datatype dtype, - 
const communicator& comm, - const stream& stream, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -event alltoallv(const vector_class& send_bufs, - const vector_class& send_counts, - const vector_class& recv_bufs, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoallv(const BufferType* send_buf, - const vector_class& send_counts, - BufferType* recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const stream& stream, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoallv(const BufferType* send_buf, - const vector_class& send_counts, - BufferType* recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoallv(const vector_class& send_bufs, - const vector_class& send_counts, - const vector_class& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const stream& stream, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoallv(const vector_class& send_bufs, - const vector_class& send_counts, - const vector_class& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoallv(const BufferObjectType& send_buf, - const vector_class& send_counts, - BufferObjectType& recv_buf, - const vector_class& recv_counts, - const 
communicator& comm, - const stream& stream, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoallv(const BufferObjectType& send_buf, - const vector_class& send_counts, - BufferObjectType& recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoallv(const vector_class>& send_bufs, - const vector_class& send_counts, - const vector_class>& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const stream& stream, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event alltoallv(const vector_class>& send_bufs, - const vector_class& send_counts, - const vector_class>& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const alltoallv_attr& attr = default_alltoallv_attr, - const vector_class& deps = {}); - -/** - * Barrier synchronization is performed across all ranks of the communicator - * and it is completed only after all the ranks in the communicator have called it. 
- */ - -/** - * @param comm the communicator for which the operation will be performed - * @param stream an optional stream associated with the operation - * @param attr optional attributes to customize operation - * @param deps an optional vector of the events that the operation should depend on - * @return @ref ccl::event an object to track the progress of the operation - */ -event barrier(const communicator& comm, - const stream& stream, - const barrier_attr& attr = default_barrier_attr, - const vector_class& deps = {}); - -event barrier(const communicator& comm, - const barrier_attr& attr = default_barrier_attr, - const vector_class& deps = {}); - -/** - * Broadcast is a collective communication operation that broadcasts data - * from one rank of communicator (denoted as root) to all other ranks. - */ - -/** - * @param buf [in,out] the buffer with @c count elements of @c dtype - * serves as send buffer for root and as receive buffer for other ranks - * @param count the number of elements of type @c dtype in @c buf - * @param dtype the datatype of elements in @c buf - * @param root the rank that broadcasts @c buf - * @param comm the communicator for which the operation will be performed - * @param stream an optional stream associated with the operation - * @param attr optional attributes to customize operation - * @param deps an optional vector of the events that the operation should depend on - * @return @ref ccl::event an object to track the progress of the operation - */ -event broadcast(void* buf, - size_t count, - datatype dtype, - size_t root, - const communicator& comm, - const stream& stream, - const broadcast_attr& attr = default_broadcast_attr, - const vector_class& deps = {}); - -event broadcast(void* buf, - size_t count, - datatype dtype, - size_t root, - const communicator& comm, - const broadcast_attr& attr = default_broadcast_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event 
broadcast(BufferType* buf, - size_t count, - size_t root, - const communicator& comm, - const stream& stream, - const broadcast_attr& attr = default_broadcast_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event broadcast(BufferType* buf, - size_t count, - size_t root, - const communicator& comm, - const broadcast_attr& attr = default_broadcast_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event broadcast(BufferObjectType& buf, - size_t count, - size_t root, - const communicator& comm, - const stream& stream, - const broadcast_attr& attr = default_broadcast_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event broadcast(BufferObjectType& buf, - size_t count, - size_t root, - const communicator& comm, - const broadcast_attr& attr = default_broadcast_attr, - const vector_class& deps = {}); - -/** - * Reduce is a collective communication operation that performs the global reduction operation - * on values from all ranks of the communicator and returns the result to the root rank. - */ - -/** - * @param send_buf the buffer with @c count elements of @c dtype that stores local data to be reduced - * @param recv_buf [out] the buffer to store reduced result, must have the same dimension as @c send_buf. - * Used by the @c root rank only, ignored by other ranks. 
- * @param count the number of elements of type @c dtype in @c send_buf and @c recv_buf - * @param dtype the datatype of elements in @c send_buf and @c recv_buf - * @param rtype the type of the reduction operation to be applied - * @param root the rank that gets the result of reduction - * @param comm the communicator for which the operation will be performed - * @param stream an optional stream associated with the operation - * @param attr optional attributes to customize operation - * @param deps an optional vector of the events that the operation should depend on - * @return @ref ccl::event an object to track the progress of the operation - */ -event reduce(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - reduction rtype, - size_t root, - const communicator& comm, - const stream& stream, - const reduce_attr& attr = default_reduce_attr, - const vector_class& deps = {}); - -event reduce(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - reduction rtype, - size_t root, - const communicator& comm, - const reduce_attr& attr = default_reduce_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event reduce(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - reduction rtype, - size_t root, - const communicator& comm, - const stream& stream, - const reduce_attr& attr = default_reduce_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event reduce(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - reduction rtype, - size_t root, - const communicator& comm, - const reduce_attr& attr = default_reduce_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event reduce(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - reduction rtype, - size_t root, - const communicator& comm, - const stream& stream, - const reduce_attr& 
attr = default_reduce_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event reduce(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - reduction rtype, - size_t root, - const communicator& comm, - const reduce_attr& attr = default_reduce_attr, - const vector_class& deps = {}); - -/** - * Reduce-scatter is a collective communication operation that performs the global reduction operation - * on values from all ranks of the communicator and scatters the result in blocks back to all ranks. - */ - -/** - * @param send_buf the buffer with @c comm_size * @c count elements of @c dtype that stores local data to be reduced - * @param recv_buf [out] the buffer to store result block containing @c recv_count elements of type @c dtype - * @param recv_count the number of elements of type @c dtype in receive block - * @param dtype the datatype of elements in @c send_buf and @c recv_buf - * @param rtype the type of the reduction operation to be applied - * @param comm the communicator for which the operation will be performed - * @param stream an optional stream associated with the operation - * @param attr optional attributes to customize operation - * @param deps an optional vector of the events that the operation should depend on - * @return @ref ccl::event an object to track the progress of the operation - */ -event reduce_scatter(const void* send_buf, - void* recv_buf, - size_t recv_count, - datatype dtype, - reduction rtype, - const communicator& comm, - const stream& stream, - const reduce_scatter_attr& attr = default_reduce_scatter_attr, - const vector_class& deps = {}); - -event reduce_scatter(const void* send_buf, - void* recv_buf, - size_t recv_count, - datatype dtype, - reduction rtype, - const communicator& comm, - const reduce_scatter_attr& attr = default_reduce_scatter_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event 
reduce_scatter(const BufferType* send_buf, - BufferType* recv_buf, - size_t recv_count, - reduction rtype, - const communicator& comm, - const stream& stream, - const reduce_scatter_attr& attr = default_reduce_scatter_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event reduce_scatter(const BufferType* send_buf, - BufferType* recv_buf, - size_t recv_count, - reduction rtype, - const communicator& comm, - const reduce_scatter_attr& attr = default_reduce_scatter_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event reduce_scatter(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t recv_count, - reduction rtype, - const communicator& comm, - const stream& stream, - const reduce_scatter_attr& attr = default_reduce_scatter_attr, - const vector_class& deps = {}); - -/* Type safety version */ -template (), event>::type> -event reduce_scatter(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t recv_count, - reduction rtype, - const communicator& comm, - const reduce_scatter_attr& attr = default_reduce_scatter_attr, - const vector_class& deps = {}); - -namespace preview { - -/** - * Sparse allreduce is a collective communication operation that makes global reduction operation - * on sparse buffers from all ranks of communicator and distributes result back to all ranks. - * Sparse buffers are defined by separate index and value buffers. 
- */ - -/** - * @param send_ind_buf the buffer of indices with @c send_ind_count elements of type @c ind_dtype - * @param send_ind_count the number of elements of type @c ind_type @c send_ind_buf - * @param send_val_buf the buffer of values with @c send_val_count elements of type @c val_dtype - * @param send_val_count the number of elements of type @c val_type @c send_val_buf - * @param recv_ind_buf [out] the buffer to store reduced indices, unused - * @param recv_ind_count [out] the number of elements in @c recv_ind_buf, unused - * @param recv_val_buf [out] the buffer to store reduced values, unused - * @param recv_val_count [out] the number of elements in @c recv_val_buf, unused - * @param ind_dtype the datatype of elements in @c send_ind_buf and @c recv_ind_buf - * @param val_dtype the datatype of elements in @c send_val_buf and @c recv_val_buf - * @param rtype the type of the reduction operation to be applied - * @param comm the communicator for which the operation will be performed - * @param stream an optional stream associated with the operation - * @param attr optional attributes to customize operation - * @param deps an optional vector of the events that the operation should depend on - * @return @ref ccl::event an object to track the progress of the operation - */ - -ccl::event sparse_allreduce( - const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - ccl::datatype ind_dtype, - ccl::datatype val_dtype, - ccl::reduction rtype, - const ccl::communicator& comm, - const ccl::stream& stream, - const ccl::sparse_allreduce_attr& attr = ccl::default_sparse_allreduce_attr, - const ccl::vector_class& deps = {}); - -ccl::event sparse_allreduce( - const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - 
size_t recv_val_count, - ccl::datatype ind_dtype, - ccl::datatype val_dtype, - ccl::reduction rtype, - const ccl::communicator& comm, - const ccl::sparse_allreduce_attr& attr = ccl::default_sparse_allreduce_attr, - const ccl::vector_class& deps = {}); - -/* Type safety version */ -template (), - ccl::event>::type> -ccl::event sparse_allreduce( - const IndexBufferType* send_ind_buf, - size_t send_ind_count, - const ValueBufferType* send_val_buf, - size_t send_val_count, - IndexBufferType* recv_ind_buf, - size_t recv_ind_count, - ValueBufferType* recv_val_buf, - size_t recv_val_count, - ccl::reduction rtype, - const ccl::communicator& comm, - const ccl::stream& stream, - const ccl::sparse_allreduce_attr& attr = default_sparse_allreduce_attr, - const ccl::vector_class& deps = {}); - -/* Type safety version */ -template (), - ccl::event>::type> -ccl::event sparse_allreduce( - const IndexBufferType* send_ind_buf, - size_t send_ind_count, - const ValueBufferType* send_val_buf, - size_t send_val_count, - IndexBufferType* recv_ind_buf, - size_t recv_ind_count, - ValueBufferType* recv_val_buf, - size_t recv_val_count, - ccl::reduction rtype, - const ccl::communicator& comm, - const ccl::sparse_allreduce_attr& attr = default_sparse_allreduce_attr, - const ccl::vector_class& deps = {}); - -} // namespace preview - -} // namespace ccl +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. Please include 'ccl.hpp'" +#endif + +namespace ccl { + +namespace v1 { + +/******************** INIT ********************/ + +/** + * Creates an attribute object, which may used to control init operation + * @return an attribute object + */ +template +init_attr create_init_attr(attr_value_pair_t&&... avps) { + return detail::environment::create_init_attr(std::forward(avps)...); +} + +/** + * Initializes the library. Optional for invocation. 
+ * @param attr optional init attributes + */ +void init(const init_attr& attr = default_init_attr); + +/** + * Retrieves the library version + */ +library_version get_library_version(); + +/******************** DATATYPE ********************/ + +/** + * Creates an attribute object, which may used to register custom datatype + * @return an attribute object + */ +template +datatype_attr create_datatype_attr(attr_value_pair_t&&... avps) { + return detail::environment::create_datatype_attr(std::forward(avps)...); +} + +/** + * Registers custom datatype to be used in communication operations + * @param attr datatype attributes + * @return datatype handle + */ +datatype register_datatype(const datatype_attr& attr); + +/** + * Deregisters custom datatype + * @param dtype custom datatype handle + */ +void deregister_datatype(datatype dtype); + +/** + * Retrieves a datatype size in bytes + * @param dtype datatype handle + * @return datatype size + */ +size_t get_datatype_size(datatype dtype); + +/******************** KVS ********************/ + +template +kvs_attr create_kvs_attr(attr_value_pair_t&&... avps) { + return detail::environment::create_kvs_attr(std::forward(avps)...); +} + +/** + * Creates a main key-value store. + * It's address should be distributed using out of band communication mechanism + * and be used to create key-value stores on other processes. 
+ * @param attr optional kvs attributes + * @return kvs object + */ +shared_ptr_class create_main_kvs(const kvs_attr& attr = default_kvs_attr); + +/** + * Creates a new key-value store from main kvs address + * @param addr address of main kvs + * @param attr optional kvs attributes + * @return kvs object + */ +shared_ptr_class create_kvs(const kvs::address_type& addr, + const kvs_attr& attr = default_kvs_attr); + +/******************** DEVICE ********************/ + +/** + * Creates a new device from @native_device_type + * @param native_device the existing handle of device + * @return device object + */ +template ()>::type> +device create_device(native_device_type&& native_device) { + return detail::environment::instance().create_device( + std::forward(native_device)); +} + +device create_device(); + +/******************** CONTEXT ********************/ + +/** + * Creates a new context from @native_contex_type + * @param native_context the existing handle of context + * @return context object + */ +template ()>::type> +context create_context(native_context_type&& native_context) { + return detail::environment::instance().create_context( + std::forward(native_context)); +} + +context create_context(); + +/******************** EVENT ********************/ + +/** + * Creates a new event from @native_event_type + * @param native_event the existing event + * @return event object + */ +template ()>::type> +event create_event(event_type& native_event) { + return detail::environment::instance().create_event(native_event); +} + +/******************** STREAM ********************/ + +/** + * Creates a new stream from @native_stream_type + * @param native_stream the existing handle of stream + * @return stream object + */ +template ()>::type> +stream create_stream(native_stream_type& native_stream) { + return detail::environment::instance().create_stream(native_stream); +} + +stream create_stream(); + +/******************** COMMUNICATOR ********************/ + +/** + * Creates 
an attribute object, which may used to control create communicator operation + * @return an attribute object + */ +template +comm_attr create_comm_attr(attr_value_pair_t&&... avps) { + return detail::environment::create_comm_attr(std::forward(avps)...); +} + +} // namespace v1 + +namespace preview { + +/** + * Creates an attribute object, which may used to control split communicator operation + * @return an attribute object + */ +template +comm_split_attr create_comm_split_attr(attr_value_pair_t&&... avps) { + return detail::environment::create_comm_split_attr(std::forward(avps)...); +} + +} // namespace preview + +namespace v1 { + +/** + * Creates a new communicator with user supplied size, rank and kvs. + * @param size user-supplied total number of ranks + * @param rank user-supplied rank + * @param kvs key-value store for ranks wire-up + * @return communicator + */ +communicator create_communicator(int size, + int rank, + shared_ptr_class kvs, + const comm_attr& attr = default_comm_attr); + +/** + * Creates a new communicators with user supplied size, ranks, local device-rank mapping and kvs. 
+ * @param size user-supplied total number of ranks + * @param device local device + * @param devices user-supplied mapping of local ranks on devices + * @param context context containing the devices + * @param kvs key-value store for ranks wire-up + * @return vector of communicators + */ +template +vector_class create_communicators( + int size, + const vector_class>& devices, + const ContextType& context, + shared_ptr_class kvs, + const comm_attr& attr = default_comm_attr) { + return detail::environment::instance().create_communicators(size, devices, context, kvs, attr); +} + +template +vector_class create_communicators(int size, + const map_class& devices, + const ContextType& context, + shared_ptr_class kvs, + const comm_attr& attr = default_comm_attr) { + return detail::environment::instance().create_communicators(size, devices, context, kvs, attr); +} + +template +communicator create_communicator(int size, + int rank, + DeviceType& device, + const ContextType& context, + shared_ptr_class kvs, + const comm_attr& attr = default_comm_attr) { + auto comms = detail::environment::instance().create_communicators( + size, + ccl::vector_class>{ { rank, device } }, + context, + kvs, + attr); + + if (comms.size() != 1) + throw ccl::exception("unexpected comm vector size"); + + return std::move(comms[0]); +} + +} // namespace v1 + +namespace preview { + +/** + * Splits communicators according to attributes. + * @param attrs split attributes for local communicators + * @return vector of communicators + */ +vector_class split_communicators( + const vector_class>& attrs); + +/** + * Creates a new communicator with externally provided size, rank and kvs. + * Implementation is platform specific and non portable. + * @return communicator + */ +communicator create_communicator(const comm_attr& attr = default_comm_attr); + +/** + * Creates a new communicator with user supplied size and kvs. + * Rank will be assigned automatically. 
+ * @param size user-supplied total number of ranks + * @param kvs key-value store for ranks wire-up + * @return communicator + */ +communicator create_communicator(int size, + shared_ptr_class kvs, + const comm_attr& attr = default_comm_attr); + +/** + * Creates a new communicators with user supplied size, local devices and kvs. + * Ranks will be assigned automatically. + * @param size user-supplied total number of ranks + * @param devices user-supplied device objects for local ranks + * @param context context containing the devices + * @param kvs key-value store for ranks wire-up + * @return vector of communicators + */ +template +vector_class create_communicators(int size, + const vector_class& devices, + const ContextType& context, + shared_ptr_class kvs, + const comm_attr& attr = default_comm_attr) { + return detail::environment::instance().create_communicators(size, devices, context, kvs, attr); +} + +} // namespace preview + +namespace v1 { + +/******************** OPERATION ********************/ + +/** + * Creates an attribute object, which may used to customize communication operation + * @return an attribute object + */ +template +coll_attribute_type CCL_API create_operation_attr(attr_value_pair_t&&... avps) { + return detail::environment::create_operation_attr( + std::forward(avps)...); +} + +/** + * Allgatherv is a collective communication operation that collects data + * from all the ranks within a communicator into a single buffer. + * Different ranks may contribute segments of different sizes. + * The resulting data in the output buffer must be the same for each rank. 
+ */ + +/** + * @param send_buf the buffer with @c send_count elements of @c dtype that stores local data to be gathered + * @param send_count the number of elements of type @c dtype in @c send_buf + * @param recv_buf [out] the buffer to store gathered result, should be large enough to hold values from all ranks + * @param recv_bufs [out] array of buffers to store gathered result, one buffer per each rank + * @param recv_counts array with the number of elements of type @c dtype to be received from each rank + * @param dtype the datatype of elements in @c send_buf and @c recv_buf + * @param comm the communicator for which the operation will be performed + * @param stream a stream associated with the operation + * @param attr optional attributes to customize operation + * @param deps an optional vector of the events that the operation should depend on + * @return @ref ccl::event an object to track the progress of the operation + */ +event allgatherv(const void* send_buf, + size_t send_count, + void* recv_buf, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const stream& stream, + const allgatherv_attr& attr = default_allgatherv_attr, + const vector_class& deps = {}); + +event allgatherv(const void* send_buf, + size_t send_count, + void* recv_buf, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const allgatherv_attr& attr = default_allgatherv_attr, + const vector_class& deps = {}); + +event allgatherv(const void* send_buf, + size_t send_count, + const vector_class& recv_bufs, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const stream& stream, + const allgatherv_attr& attr = default_allgatherv_attr, + const vector_class& deps = {}); + +event allgatherv(const void* send_buf, + size_t send_count, + const vector_class& recv_bufs, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const allgatherv_attr& attr = default_allgatherv_attr, + 
const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event allgatherv(const BufferType* send_buf, + size_t send_count, + BufferType* recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const stream& stream, + const allgatherv_attr& attr = default_allgatherv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event allgatherv(const BufferType* send_buf, + size_t send_count, + BufferType* recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const allgatherv_attr& attr = default_allgatherv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event allgatherv(const BufferType* send_buf, + size_t send_count, + vector_class& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const stream& stream, + const allgatherv_attr& attr = default_allgatherv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event allgatherv(const BufferType* send_buf, + size_t send_count, + vector_class& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const allgatherv_attr& attr = default_allgatherv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event allgatherv(const BufferObjectType& send_buf, + size_t send_count, + BufferObjectType& recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const stream& stream, + const allgatherv_attr& attr = default_allgatherv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event allgatherv(const BufferObjectType& send_buf, + size_t send_count, + BufferObjectType& recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const allgatherv_attr& attr = default_allgatherv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), 
event>::type> +event allgatherv(const BufferObjectType& send_buf, + size_t send_count, + vector_class>& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const stream& stream, + const allgatherv_attr& attr = default_allgatherv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event allgatherv(const BufferObjectType& send_buf, + size_t send_count, + vector_class>& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const allgatherv_attr& attr = default_allgatherv_attr, + const vector_class& deps = {}); + +/** + * Allreduce is a collective communication operation that performs the global reduction operation + * on values from all ranks of communicator and distributes the result back to all ranks. + */ + +/** + * @param send_buf the buffer with @c count elements of @c dtype that stores local data to be reduced + * @param recv_buf [out] the buffer to store reduced result, must have the same dimension as @c send_buf + * @param count the number of elements of type @c dtype in @c send_buf and @c recv_buf + * @param dtype the datatype of elements in @c send_buf and @c recv_buf + * @param rtype the type of the reduction operation to be applied + * @param comm the communicator for which the operation will be performed + * @param stream a stream associated with the operation + * @param attr optional attributes to customize operation + * @param deps an optional vector of the events that the operation should depend on + * @return @ref ccl::event an object to track the progress of the operation + */ +event allreduce(const void* send_buf, + void* recv_buf, + size_t count, + datatype dtype, + reduction rtype, + const communicator& comm, + const stream& stream, + const allreduce_attr& attr = default_allreduce_attr, + const vector_class& deps = {}); + +event allreduce(const void* send_buf, + void* recv_buf, + size_t count, + datatype dtype, + reduction rtype, + const communicator& comm, 
+ const allreduce_attr& attr = default_allreduce_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event allreduce(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + reduction rtype, + const communicator& comm, + const stream& stream, + const allreduce_attr& attr = default_allreduce_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event allreduce(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + reduction rtype, + const communicator& comm, + const allreduce_attr& attr = default_allreduce_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event allreduce(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + reduction rtype, + const communicator& comm, + const stream& stream, + const allreduce_attr& attr = default_allreduce_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event allreduce(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + reduction rtype, + const communicator& comm, + const allreduce_attr& attr = default_allreduce_attr, + const vector_class& deps = {}); + +/** + * Alltoall is a collective communication operation in which each rank + * sends distinct equal-sized blocks of data to each rank. + * The j-th block of @c send_buf sent from the i-th rank is received by the j-th rank + * and is placed in the i-th block of @c recvbuf. + */ + +/** + * @param send_buf the buffer with @c count elements of @c dtype that stores local data to be sent + * @param recv_buf [out] the buffer to store received result, should be large enough + * to hold values from all ranks, i.e. 
at least @c comm_size * @c count + * @param send_bufs array of buffers with local data to be sent, one buffer per each rank + * @param recv_bufs [out] array of buffers to store received result, one buffer per each rank + * @param count the number of elements of type @c dtype to be send to or to received from each rank + * @param dtype the datatype of elements in @c send_buf and @c recv_buf + * @param comm the communicator for which the operation will be performed + * @param stream a stream associated with the operation + * @param attr optional attributes to customize operation + * @param deps an optional vector of the events that the operation should depend on + * @return @ref ccl::event an object to track the progress of the operation + */ +event alltoall(const void* send_buf, + void* recv_buf, + size_t count, + datatype dtype, + const communicator& comm, + const stream& stream, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +event alltoall(const void* send_buf, + void* recv_buf, + size_t count, + datatype dtype, + const communicator& comm, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +event alltoall(const vector_class& send_buf, + const vector_class& recv_buf, + size_t count, + datatype dtype, + const communicator& comm, + const stream& stream, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +event alltoall(const vector_class& send_buf, + const vector_class& recv_buf, + size_t count, + datatype dtype, + const communicator& comm, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoall(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + const communicator& comm, + const stream& stream, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), 
event>::type> +event alltoall(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + const communicator& comm, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoall(const vector_class& send_buf, + const vector_class& recv_buf, + size_t count, + const communicator& comm, + const stream& stream, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoall(const vector_class& send_buf, + const vector_class& recv_buf, + size_t count, + const communicator& comm, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoall(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + const communicator& comm, + const stream& stream, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoall(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + const communicator& comm, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoall(const vector_class>& send_buf, + const vector_class>& recv_buf, + size_t count, + const communicator& comm, + const stream& stream, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoall(const vector_class>& send_buf, + const vector_class>& recv_buf, + size_t count, + const communicator& comm, + const alltoall_attr& attr = default_alltoall_attr, + const vector_class& deps = {}); + +/** + * Alltoallv is a collective communication operation in which each rank + * sends distinct blocks of 
data to each rank. Block sizes may differ. + * The j-th block of @c send_buf sent from the i-th rank is received by the j-th rank + * and is placed in the i-th block of @c recvbuf. + */ + +/** + * @param send_buf the buffer with elements of @c dtype that stores local blocks to be sent to each rank + * @param send_bufs array of buffers to store send blocks, one buffer per each rank + * @param recv_buf [out] the buffer to store received result, should be large enough to hold blocks from all ranks + * @param recv_bufs [out] array of buffers to store receive blocks, one buffer per each rank + * @param send_counts array with the number of elements of type @c dtype in send blocks for each rank + * @param recv_counts array with the number of elements of type @c dtype in receive blocks from each rank + * @param dtype the datatype of elements in @c send_buf and @c recv_buf + * @param comm the communicator for which the operation will be performed + * @param stream a stream associated with the operation + * @param attr optional attributes to customize operation + * @param deps an optional vector of the events that the operation should depend on + * @return @ref ccl::event an object to track the progress of the operation + */ +event alltoallv(const void* send_buf, + const vector_class& send_counts, + void* recv_buf, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const stream& stream, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +event alltoallv(const void* send_buf, + const vector_class& send_counts, + void* recv_buf, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +event alltoallv(const vector_class& send_bufs, + const vector_class& send_counts, + const vector_class& recv_bufs, + const vector_class& recv_counts, + datatype dtype, + const 
communicator& comm, + const stream& stream, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +event alltoallv(const vector_class& send_bufs, + const vector_class& send_counts, + const vector_class& recv_bufs, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoallv(const BufferType* send_buf, + const vector_class& send_counts, + BufferType* recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const stream& stream, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoallv(const BufferType* send_buf, + const vector_class& send_counts, + BufferType* recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoallv(const vector_class& send_bufs, + const vector_class& send_counts, + const vector_class& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const stream& stream, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoallv(const vector_class& send_bufs, + const vector_class& send_counts, + const vector_class& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoallv(const BufferObjectType& send_buf, + const vector_class& send_counts, + BufferObjectType& recv_buf, + const vector_class& recv_counts, + const 
communicator& comm, + const stream& stream, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoallv(const BufferObjectType& send_buf, + const vector_class& send_counts, + BufferObjectType& recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoallv(const vector_class>& send_bufs, + const vector_class& send_counts, + const vector_class>& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const stream& stream, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event alltoallv(const vector_class>& send_bufs, + const vector_class& send_counts, + const vector_class>& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const alltoallv_attr& attr = default_alltoallv_attr, + const vector_class& deps = {}); + +/** + * Barrier synchronization is performed across all ranks of the communicator + * and it is completed only after all the ranks in the communicator have called it. 
+ */ + +/** + * @param comm the communicator for which the operation will be performed + * @param stream a stream associated with the operation + * @param attr optional attributes to customize operation + * @param deps an optional vector of the events that the operation should depend on + * @return @ref ccl::event an object to track the progress of the operation + */ +event barrier(const communicator& comm, + const stream& stream, + const barrier_attr& attr = default_barrier_attr, + const vector_class& deps = {}); + +event barrier(const communicator& comm, + const barrier_attr& attr = default_barrier_attr, + const vector_class& deps = {}); + +/** + * Broadcast is a collective communication operation that broadcasts data + * from one rank of communicator (denoted as root) to all other ranks. + */ + +/** + * @param buf [in,out] the buffer with @c count elements of @c dtype + * serves as send buffer for root and as receive buffer for other ranks + * @param count the number of elements of type @c dtype in @c buf + * @param dtype the datatype of elements in @c buf + * @param root the rank that broadcasts @c buf + * @param comm the communicator for which the operation will be performed + * @param stream a stream associated with the operation + * @param attr optional attributes to customize operation + * @param deps an optional vector of the events that the operation should depend on + * @return @ref ccl::event an object to track the progress of the operation + */ +event broadcast(void* buf, + size_t count, + datatype dtype, + int root, + const communicator& comm, + const stream& stream, + const broadcast_attr& attr = default_broadcast_attr, + const vector_class& deps = {}); + +event broadcast(void* buf, + size_t count, + datatype dtype, + int root, + const communicator& comm, + const broadcast_attr& attr = default_broadcast_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event broadcast(BufferType* buf, + size_t count, + 
int root, + const communicator& comm, + const stream& stream, + const broadcast_attr& attr = default_broadcast_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event broadcast(BufferType* buf, + size_t count, + int root, + const communicator& comm, + const broadcast_attr& attr = default_broadcast_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event broadcast(BufferObjectType& buf, + size_t count, + int root, + const communicator& comm, + const stream& stream, + const broadcast_attr& attr = default_broadcast_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event broadcast(BufferObjectType& buf, + size_t count, + int root, + const communicator& comm, + const broadcast_attr& attr = default_broadcast_attr, + const vector_class& deps = {}); + +/** + * Reduce is a collective communication operation that performs the global reduction operation + * on values from all ranks of the communicator and returns the result to the root rank. + */ + +/** + * @param send_buf the buffer with @c count elements of @c dtype that stores local data to be reduced + * @param recv_buf [out] the buffer to store reduced result, must have the same dimension as @c send_buf. + * Used by the @c root rank only, ignored by other ranks. 
+ * @param count the number of elements of type @c dtype in @c send_buf and @c recv_buf + * @param dtype the datatype of elements in @c send_buf and @c recv_buf + * @param rtype the type of the reduction operation to be applied + * @param root the rank that gets the result of reduction + * @param comm the communicator for which the operation will be performed + * @param stream a stream associated with the operation + * @param attr optional attributes to customize operation + * @param deps an optional vector of the events that the operation should depend on + * @return @ref ccl::event an object to track the progress of the operation + */ +event reduce(const void* send_buf, + void* recv_buf, + size_t count, + datatype dtype, + reduction rtype, + int root, + const communicator& comm, + const stream& stream, + const reduce_attr& attr = default_reduce_attr, + const vector_class& deps = {}); + +event reduce(const void* send_buf, + void* recv_buf, + size_t count, + datatype dtype, + reduction rtype, + int root, + const communicator& comm, + const reduce_attr& attr = default_reduce_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event reduce(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + reduction rtype, + int root, + const communicator& comm, + const stream& stream, + const reduce_attr& attr = default_reduce_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event reduce(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + reduction rtype, + int root, + const communicator& comm, + const reduce_attr& attr = default_reduce_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event reduce(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + reduction rtype, + int root, + const communicator& comm, + const stream& stream, + const reduce_attr& attr = default_reduce_attr, 
+ const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event reduce(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + reduction rtype, + int root, + const communicator& comm, + const reduce_attr& attr = default_reduce_attr, + const vector_class& deps = {}); + +/** + * Reduce-scatter is a collective communication operation that performs the global reduction operation + * on values from all ranks of the communicator and scatters the result in blocks back to all ranks. + */ + +/** + * @param send_buf the buffer with @c comm_size * @c recv_count elements of @c dtype that stores local data to be reduced + * @param recv_buf [out] the buffer to store result block containing @c recv_count elements of type @c dtype + * @param recv_count the number of elements of type @c dtype in receive block + * @param dtype the datatype of elements in @c send_buf and @c recv_buf + * @param rtype the type of the reduction operation to be applied + * @param comm the communicator for which the operation will be performed + * @param stream a stream associated with the operation + * @param attr optional attributes to customize operation + * @param deps an optional vector of the events that the operation should depend on + * @return @ref ccl::event an object to track the progress of the operation + */ +event reduce_scatter(const void* send_buf, + void* recv_buf, + size_t recv_count, + datatype dtype, + reduction rtype, + const communicator& comm, + const stream& stream, + const reduce_scatter_attr& attr = default_reduce_scatter_attr, + const vector_class& deps = {}); + +event reduce_scatter(const void* send_buf, + void* recv_buf, + size_t recv_count, + datatype dtype, + reduction rtype, + const communicator& comm, + const reduce_scatter_attr& attr = default_reduce_scatter_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event reduce_scatter(const BufferType* send_buf, + BufferType* 
recv_buf, + size_t recv_count, + reduction rtype, + const communicator& comm, + const stream& stream, + const reduce_scatter_attr& attr = default_reduce_scatter_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event reduce_scatter(const BufferType* send_buf, + BufferType* recv_buf, + size_t recv_count, + reduction rtype, + const communicator& comm, + const reduce_scatter_attr& attr = default_reduce_scatter_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event reduce_scatter(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t recv_count, + reduction rtype, + const communicator& comm, + const stream& stream, + const reduce_scatter_attr& attr = default_reduce_scatter_attr, + const vector_class& deps = {}); + +/* Type safety version */ +template (), event>::type> +event reduce_scatter(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t recv_count, + reduction rtype, + const communicator& comm, + const reduce_scatter_attr& attr = default_reduce_scatter_attr, + const vector_class& deps = {}); + +} // namespace v1 + +namespace preview { + +/** + * Sparse allreduce is a collective communication operation that makes global reduction operation + * on sparse buffers from all ranks of communicator and distributes result back to all ranks. + * Sparse buffers are defined by separate index and value buffers. 
+ */ + +/** + * @param send_ind_buf the buffer of indices with @c send_ind_count elements of type @c ind_dtype + * @param send_ind_count the number of elements of type @c ind_dtype in @c send_ind_buf + * @param send_val_buf the buffer of values with @c send_val_count elements of type @c val_dtype + * @param send_val_count the number of elements of type @c val_dtype in @c send_val_buf + * @param recv_ind_buf [out] the buffer to store reduced indices, unused + * @param recv_ind_count [out] the number of elements in @c recv_ind_buf, unused + * @param recv_val_buf [out] the buffer to store reduced values, unused + * @param recv_val_count [out] the number of elements in @c recv_val_buf, unused + * @param ind_dtype the datatype of elements in @c send_ind_buf and @c recv_ind_buf + * @param val_dtype the datatype of elements in @c send_val_buf and @c recv_val_buf + * @param rtype the type of the reduction operation to be applied + * @param comm the communicator for which the operation will be performed + * @param stream a stream associated with the operation + * @param attr optional attributes to customize operation + * @param deps an optional vector of the events that the operation should depend on + * @return @ref ccl::event an object to track the progress of the operation + */ + +ccl::event sparse_allreduce( + const void* send_ind_buf, + size_t send_ind_count, + const void* send_val_buf, + size_t send_val_count, + void* recv_ind_buf, + size_t recv_ind_count, + void* recv_val_buf, + size_t recv_val_count, + ccl::datatype ind_dtype, + ccl::datatype val_dtype, + ccl::reduction rtype, + const ccl::communicator& comm, + const ccl::stream& stream, + const ccl::sparse_allreduce_attr& attr = ccl::default_sparse_allreduce_attr, + const ccl::vector_class& deps = {}); + +ccl::event sparse_allreduce( + const void* send_ind_buf, + size_t send_ind_count, + const void* send_val_buf, + size_t send_val_count, + void* recv_ind_buf, + size_t recv_ind_count, + void* recv_val_buf, + size_t 
recv_val_count, + ccl::datatype ind_dtype, + ccl::datatype val_dtype, + ccl::reduction rtype, + const ccl::communicator& comm, + const ccl::sparse_allreduce_attr& attr = ccl::default_sparse_allreduce_attr, + const ccl::vector_class& deps = {}); + +/* Type safety version */ +template (), + ccl::event>::type> +ccl::event sparse_allreduce( + const IndexBufferType* send_ind_buf, + size_t send_ind_count, + const ValueBufferType* send_val_buf, + size_t send_val_count, + IndexBufferType* recv_ind_buf, + size_t recv_ind_count, + ValueBufferType* recv_val_buf, + size_t recv_val_count, + ccl::reduction rtype, + const ccl::communicator& comm, + const ccl::stream& stream, + const ccl::sparse_allreduce_attr& attr = ccl::default_sparse_allreduce_attr, + const ccl::vector_class& deps = {}); + +/* Type safety version */ +template (), + ccl::event>::type> +ccl::event sparse_allreduce( + const IndexBufferType* send_ind_buf, + size_t send_ind_count, + const ValueBufferType* send_val_buf, + size_t send_val_count, + IndexBufferType* recv_ind_buf, + size_t recv_ind_count, + ValueBufferType* recv_val_buf, + size_t recv_val_count, + ccl::reduction rtype, + const ccl::communicator& comm, + const ccl::sparse_allreduce_attr& attr = ccl::default_sparse_allreduce_attr, + const ccl::vector_class& deps = {}); + +} // namespace preview + +using namespace v1; + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_environment.hpp b/include/oneapi/ccl/ccl_environment.hpp deleted file mode 100644 index 657625550..000000000 --- a/include/oneapi/ccl/ccl_environment.hpp +++ /dev/null @@ -1,222 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include -#include -#include -#include - -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_type_traits.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_coll_attr.hpp" - -#include "oneapi/ccl/ccl_comm_split_attr_ids.hpp" -#include "oneapi/ccl/ccl_comm_split_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_comm_split_attr.hpp" - -#include "oneapi/ccl/ccl_context_attr_ids.hpp" -#include "oneapi/ccl/ccl_context_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_context.hpp" - -#include "oneapi/ccl/ccl_datatype_attr_ids.hpp" -#include "oneapi/ccl/ccl_datatype_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_datatype_attr.hpp" - -#include "oneapi/ccl/ccl_device_attr_ids.hpp" -#include "oneapi/ccl/ccl_device_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_device.hpp" - -#include "oneapi/ccl/ccl_kvs.hpp" - -#include "oneapi/ccl/ccl_event.hpp" - -#include "oneapi/ccl/ccl_stream_attr_ids.hpp" -#include "oneapi/ccl/ccl_stream_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_stream.hpp" - -#include "oneapi/ccl/ccl_communicator.hpp" - -#include "oneapi/ccl/ccl_exception.hpp" - -namespace ccl { - -/** - * CCL environment singleton - */ -class environment { -public: - ~environment(); - - /** - * Retrieves the unique environment object - * and makes the first-time initialization of CCL library - */ - static environment& instance(); - - ccl::library_version get_library_version() const; - - template - datatype_attr 
create_datatype_attr(attr_value_pair_t&&... avps) const { - static_assert(sizeof...(avps) > 0, "At least one argument must be specified"); - auto attr = create_postponed_api_type(); - int expander[]{ (attr.template set(avps.val()), 0)... }; - (void)expander; - return attr; - } - - ccl::datatype register_datatype(const ccl::datatype_attr& attr); - void deregister_datatype(ccl::datatype dtype); - size_t get_datatype_size(ccl::datatype dtype) const; - - shared_ptr_class create_main_kvs() const; - shared_ptr_class create_kvs(const kvs::address_type& addr) const; - - device create_device(empty_t empty) const; - - template ()>::type> - device create_device(native_device_type&& native_device) const; - - template - device create_device_from_attr(typename unified_device_type::ccl_native_t dev, - attr_value_pair_t&&... avps) const { - device str = create_postponed_api_type(dev); - int expander[]{ (str.template set(avps.val()), 0)... }; - (void)expander; - str.build_from_params(); - return str; - } - - context create_context(empty_t empty) const; - - template ()>::type> - context create_context(native_device_contex_type&& native_device_context) const; - - template - context create_context_from_attr(typename unified_device_context_type::ccl_native_t ctx, - attr_value_pair_t&&... avps) const { - context str = create_postponed_api_type(ctx); - int expander[]{ (str.template set(avps.val()), 0)... }; - (void)expander; - str.build_from_params(); - return str; - } - - template - coll_attribute_type create_operation_attr(attr_value_pair_t&&... avps) const { - auto op_attr = create_postponed_api_type(); - int expander[]{ (op_attr.template set(avps.val()), 0)... 
}; - (void)expander; - return op_attr; - } - - template ()>::type> - event create_event(event_type& native_event) { - return event::create_from_native(native_event); - } - - template ()>::type> - stream create_stream(native_stream_type& native_stream); - - template ()>::type> - stream create_stream(native_stream_type& native_stream, native_context_type& native_ctx); - - template - stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, - attr_value_pair_t&&... avps) { - stream str = create_postponed_api_type(device); - int expander[]{ (str.template set(avps.val()), 0)... }; - (void)expander; - str.build_from_params(); - return str; - } - - template - stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, - typename unified_device_context_type::ccl_native_t context, - attr_value_pair_t&&... avps) { - stream str = create_postponed_api_type(device, context); - int expander[]{ (str.template set(avps.val()), 0)... }; - (void)expander; - str.build_from_params(); - return str; - } - - -#ifdef CCL_ENABLE_SYCL - communicator create_single_device_communicator( - size_t comm_size, - size_t rank, - const cl::sycl::device& device, - const cl::sycl::context& context, - shared_ptr_class kvs) const; -#endif - - template - comm_split_attr create_comm_split_attr(attr_value_pair_t&&... avps) const { - auto split_attr = create_postponed_api_type(); - int expander[]{ (split_attr.template set(avps.val()), 0)... 
}; - (void)expander; - return split_attr; - } - - communicator create_communicator() const; - communicator create_communicator(size_t size, shared_ptr_class kvs) const; - communicator create_communicator(size_t size, - size_t rank, - shared_ptr_class kvs) const; - - template - vector_class create_communicators( - size_t comm_size, - const vector_class& local_devices, - ContextType& context, - shared_ptr_class kvs) const; - - template - vector_class create_communicators( - size_t comm_size, - const vector_class>& local_rank_device_map, - ContextType& context, - shared_ptr_class kvs) const; - - template - vector_class create_communicators( - size_t comm_size, - const map_class& local_rank_device_map, - ContextType& context, - shared_ptr_class kvs) const; - - vector_class split_device_communicators( - const vector_class>& attrs) const; - -private: - environment(); - - template - ccl_api_type create_postponed_api_type(args_type... args) const; -}; - -} /* ccl */ diff --git a/include/oneapi/ccl/ccl_type_traits.hpp b/include/oneapi/ccl/ccl_type_traits.hpp deleted file mode 100644 index 0964f66ec..000000000 --- a/include/oneapi/ccl/ccl_type_traits.hpp +++ /dev/null @@ -1,131 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -#ifndef TRAITS_H_ -#define TRAITS_H_ - -#include -#include - -#ifdef CCL_ENABLE_SYCL -#include -#endif - -#include "oneapi/ccl/ccl_types.hpp" - -namespace ccl { -/** - * Base type-trait helpers for "unknown" types - */ -template -struct type_info { - static constexpr bool is_supported = false; - static constexpr bool is_class = false; -}; - -template -struct native_type_info { - static constexpr bool is_supported = false; - static constexpr bool is_class = false; -}; - -#define CCL_TYPE_TRAITS(ccl_type_id, dtype, dtype_size) \ - template <> \ - struct type_info \ - : public ccl_type_info_export { \ - static constexpr const char* name() { \ - return #dtype; \ - } \ - }; \ - template <> \ - struct native_type_info : public type_info {}; - -#define CCL_CLASS_TYPE_TRAITS(ccl_type_id, dtype, sizeof_dtype) \ - template <> \ - struct native_type_info \ - : public ccl_type_info_export { \ - static constexpr const char* name() { \ - return #dtype; \ - } \ - }; - -#define COMMA , - -/*struct bf16_impl -{ - uint16_t data; -} __attribute__((packed));*/ - -using bf16 = uint16_t; - -/** - * Enumeration of supported CCL API data types - */ -CCL_TYPE_TRAITS(ccl::datatype::int8, char, sizeof(char)) -CCL_TYPE_TRAITS(ccl::datatype::int32, int, sizeof(int)) -CCL_TYPE_TRAITS(ccl::datatype::bfloat16, bf16, sizeof(bf16)) -CCL_TYPE_TRAITS(ccl::datatype::float32, float, sizeof(float)) -CCL_TYPE_TRAITS(ccl::datatype::float64, double, sizeof(double)) -CCL_TYPE_TRAITS(ccl::datatype::int64, int64_t, sizeof(int64_t)) -CCL_TYPE_TRAITS(ccl::datatype::uint64, uint64_t, sizeof(uint64_t)) - -#ifdef CCL_ENABLE_SYCL -CCL_CLASS_TYPE_TRAITS(ccl::datatype::int8, cl::sycl::buffer, sizeof(char)) -CCL_CLASS_TYPE_TRAITS(ccl::datatype::int32, cl::sycl::buffer, sizeof(int)) -CCL_CLASS_TYPE_TRAITS(ccl::datatype::bfloat16, cl::sycl::buffer, sizeof(bf16)) -CCL_CLASS_TYPE_TRAITS(ccl::datatype::int64, cl::sycl::buffer, sizeof(int64_t)) -CCL_CLASS_TYPE_TRAITS(ccl::datatype::uint64, cl::sycl::buffer, 
sizeof(uint64_t)) -CCL_CLASS_TYPE_TRAITS(ccl::datatype::float32, cl::sycl::buffer, sizeof(float)) -CCL_CLASS_TYPE_TRAITS(ccl::datatype::float64, cl::sycl::buffer, sizeof(double)) -#endif //CCL_ENABLE_SYCL - -/** - * Checks for supporting @c type in ccl API - */ -template -constexpr bool is_supported() { - using clear_type = typename std::remove_pointer::type; - // static_assert(native_type_info::is_supported, "type is not supported by ccl API"); - return native_type_info::is_supported; -} - -/** - * Checks is @c type a class - */ -template -constexpr bool is_class() { - using clear_type = typename std::remove_pointer::type; - return native_type_info::is_class; -} - -/** - * SFINAE checks for supporting native type @c type in ccl API - */ -template -constexpr bool is_native_type_supported() { - return (not is_class() and is_supported()); -} - -/** - * SFINAE checks for supporting class @c type in ccl API - */ -template -constexpr bool is_class_supported() { - return (is_class() and is_supported()); -} - -} // namespace ccl -#include "oneapi/ccl/ccl_device_type_traits.hpp" -#endif //TRAITS_H_ diff --git a/include/oneapi/ccl/ccl_types.hpp b/include/oneapi/ccl/ccl_types.hpp deleted file mode 100644 index 5e8032ab8..000000000 --- a/include/oneapi/ccl/ccl_types.hpp +++ /dev/null @@ -1,281 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -#pragma once - -#include -#include -#include "oneapi/ccl/ccl_config.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "oneapi/ccl/ccl_aliases.hpp" -#include "oneapi/ccl/ccl_exception.hpp" - -// TODO: tmp enums, refactor core code and remove them -/************************************************/ -typedef int ccl_status_t; -#define ccl_status_success (0) -#define ccl_status_out_of_resource (1) -#define ccl_status_invalid_arguments (2) -#define ccl_status_unimplemented (3) -#define ccl_status_runtime_error (4) -#define ccl_status_blocked_due_to_resize (5) -#define ccl_status_last_value (6) - -/** Resize action types. */ -typedef enum ccl_resize_action { - /* Wait additional changes for number of ranks */ - ccl_ra_wait = 0, - /* Run with current number of ranks */ - ccl_ra_run = 1, - /* Finalize work */ - ccl_ra_finalize = 2, -} ccl_resize_action_t; - -/* comm_size */ -typedef ccl_resize_action_t (*ccl_resize_fn_t)(size_t comm_size); - -/** Stream types. */ -typedef enum { - ccl_stream_host = 0, - ccl_stream_cpu = 1, - ccl_stream_gpu = 2, - - ccl_stream_last_value -} ccl_stream_type_t; -/************************************************/ - -namespace ccl { - -/** Library version description. 
*/ -typedef struct { - unsigned int major; - unsigned int minor; - unsigned int update; - const char* product_status; - const char* build_date; - const char* full; -} library_version; - -/** - * Supported reduction operations - */ -enum class reduction : int { - sum = 0, - prod, - min, - max, - custom, - - last_value -}; - -/** - * Supported datatypes - */ -enum class datatype : int { - int8 = 0, - uint8, - int16, - uint16, - int32, - uint32, - int64, - uint64, - - float16, - float32, - float64, - - bfloat16, - - last_predefined = bfloat16 -}; - -string_class to_string(const ccl::datatype& dt); - -inline std::ostream& operator<<(std::ostream& os, const ccl::datatype& dt) { - os << ccl::to_string(dt); - return os; -} - -typedef struct { - const char* match_id; - const size_t offset; -} fn_context; - -/* Sparse coalesce modes */ -/* Use this variable to set sparse_allreduce coalescing mode: - ccl_sparse_coalesce_regular run regular coalesce funtion; - ccl_sparse_coalesce_disable disables coalesce function in sparse_allreduce, - allgathered data is returned; - ccl_sparse_coalesce_keep_precision on every local reduce bf16 data is - converted to fp32, reduced and then converted - back to bf16. 
-*/ - -enum class sparse_coalesce_mode : int { regular = 0, disable = 1, keep_precision = 2 }; - -/* comm_size */ -typedef ccl_resize_action_t (*ccl_resize_fn_t)(size_t comm_size); - -/* in_buf, in_count, in_dtype, out_buf, out_count, out_dtype, context */ -typedef void (*prologue_fn)(const void*, - size_t, - ccl::datatype, - void**, - size_t*, - ccl::datatype*, - const ccl::fn_context*); - -/* in_buf, in_count, in_dtype, out_buf, out_count, out_dtype, context */ -typedef void (*epilogue_fn)(const void*, - size_t, - ccl::datatype, - void*, - size_t*, - ccl::datatype, - const ccl::fn_context*); - -/* in_buf, in_count, inout_buf, out_count, dtype, context */ -typedef void ( - *reduction_fn)(const void*, size_t, void*, size_t*, ccl::datatype, const ccl::fn_context*); - -/* idx_buf, idx_count, idx_dtype, val_buf, val_count, val_dtype, user_context */ -typedef void (*sparse_allreduce_completion_fn)(const void*, - size_t, - ccl::datatype, - const void*, - size_t, - ccl::datatype, - const void*); - -/* idx_count, idx_dtype, val_count, val_dtype, user_context, out_idx_buf, out_val_buf */ -typedef void (*sparse_allreduce_alloc_fn)(size_t, - ccl::datatype, - size_t, - ccl::datatype, - const void*, - void**, - void**); - -// using datatype_attr_t = ccl_datatype_attr_t; - -/** - * Supported CL backend types - */ -enum class cl_backend_type : int { - empty_backend = 0x0, - dpcpp_sycl = 0x1, - l0 = 0x2, - dpcpp_sycl_l0 = 0x3, - - last_value -}; -/** - * Supported stream types - */ -enum class stream_type : int { - host = 0, - cpu, - gpu, - - last_value -}; - -/** - * Type traits, which describes how-to types would be interpretered by ccl API - */ -template -struct ccl_type_info_export { - using native_type = ntype_t; - using ccl_type = std::integral_constant; - static constexpr size_t size = size_of_type; - static constexpr ccl::datatype ccl_type_value = ccl_type::value; - static constexpr datatype ccl_datatype_value = static_cast(ccl_type_value); - static constexpr bool 
is_class = iclass; - static constexpr bool is_supported = supported; -}; - -struct ccl_empty_attr { - static ccl::library_version version; - - template - static attr create_empty(); -}; - -/** - * API object attributes traits - */ -namespace info { -template -struct param_traits {}; - -} //namespace info -} // namespace ccl - -// TODO: tmp struct, refactor core code and remove it -/*********************************************************/ - -/** Extendable list of collective attributes. */ -typedef struct { - /** - * Callbacks into application code - * for pre-/post-processing data - * and custom reduction operation - */ - ccl::prologue_fn prologue_fn; - ccl::epilogue_fn epilogue_fn; - ccl::reduction_fn reduction_fn; - - /* Priority for collective operation */ - size_t priority; - - /* Blocking/non-blocking */ - int synchronous; - - /* Persistent/non-persistent */ - int to_cache; - - /* Treat buffer as vector/regular - applicable for allgatherv only */ - int vector_buf; - - /** - * Id of the operation. If specified, new communicator will be created and collective - * operations with the same @b match_id will be executed in the same order. 
- */ - const char* match_id; - - /* Sparse allreduce specific */ - ccl::sparse_allreduce_completion_fn sparse_allreduce_completion_fn; - ccl::sparse_allreduce_alloc_fn sparse_allreduce_alloc_fn; - const void* sparse_allreduce_fn_ctx; - ccl::sparse_coalesce_mode sparse_coalesce_mode; - -} ccl_coll_attr_t; - -#include "oneapi/ccl/ccl_device_types.hpp" diff --git a/include/oneapi/ccl/ccl_coll_attr.hpp b/include/oneapi/ccl/coll_attr.hpp similarity index 67% rename from include/oneapi/ccl/ccl_coll_attr.hpp rename to include/oneapi/ccl/coll_attr.hpp index 16c0ee155..30d843b45 100644 --- a/include/oneapi/ccl/ccl_coll_attr.hpp +++ b/include/oneapi/ccl/coll_attr.hpp @@ -20,6 +20,9 @@ #endif namespace ccl { +namespace detail { +class environment; +} class ccl_allgatherv_attr_impl_t; class ccl_allreduce_attr_impl_t; @@ -31,6 +34,8 @@ class ccl_reduce_attr_impl_t; class ccl_reduce_scatter_attr_impl_t; class ccl_sparse_allreduce_attr_impl_t; +namespace v1 { + struct ccl_empty_attr; /** @@ -65,30 +70,29 @@ class allgatherv_attr : public ccl_api_base_copyable()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); template ()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); /** * Get specific attribute value by @attrId */ template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; private: - friend class environment; - friend struct ccl_empty_attr; + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; allgatherv_attr( - const typename details::ccl_api_type_attr_traits::type& - 
version); + const typename detail::ccl_api_type_attr_traits::type& version); }; /** @@ -123,31 +127,30 @@ class allreduce_attr : public ccl_api_base_copyable()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); template ()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); /** * Get specific attribute value by @attrId */ template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; private: - friend class environment; - friend struct ccl_empty_attr; + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; allreduce_attr( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); }; /** @@ -181,31 +184,30 @@ class alltoall_attr : public ccl_api_base_copyable()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); template ()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); /** * Get specific attribute value by @attrId */ template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; private: - friend class environment; - 
friend struct ccl_empty_attr; + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; alltoall_attr( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); }; /** @@ -240,31 +242,30 @@ class alltoallv_attr : public ccl_api_base_copyable()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); template ()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); /** * Get specific attribute value by @attrId */ template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; private: - friend class environment; - friend struct ccl_empty_attr; + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; alltoallv_attr( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); }; /** @@ -298,30 +299,30 @@ class barrier_attr : public ccl_api_base_copyable()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); template ()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); /** * Get specific attribute value by @attrId */ template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; template - 
const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; private: - friend class environment; - friend struct ccl_empty_attr; - barrier_attr(const typename details::ccl_api_type_attr_traits::type& - version); + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; + barrier_attr( + const typename detail::ccl_api_type_attr_traits::type& version); ; }; @@ -357,31 +358,30 @@ class broadcast_attr : public ccl_api_base_copyable()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); template ()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); /** * Get specific attribute value by @attrId */ template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; private: - friend class environment; - friend struct ccl_empty_attr; + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; broadcast_attr( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); }; /** @@ -415,30 +415,30 @@ class reduce_attr : public ccl_api_base_copyable()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); template ()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); /** * Get specific 
attribute value by @attrId */ template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; private: - friend class environment; - friend struct ccl_empty_attr; - reduce_attr(const typename details::ccl_api_type_attr_traits::type& - version); + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; + reduce_attr( + const typename detail::ccl_api_type_attr_traits::type& version); }; /** @@ -473,31 +473,30 @@ class reduce_scatter_attr : public ccl_api_base_copyable()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); template ()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); /** * Get specific attribute value by @attrId */ template - const typename details::ccl_api_type_attr_traits::return_type& + const typename detail::ccl_api_type_attr_traits::return_type& get() const; template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; private: - friend class environment; - friend struct ccl_empty_attr; + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; reduce_scatter_attr( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); ; }; @@ -533,31 +532,30 @@ class sparse_allreduce_attr : public ccl_api_base_copyable()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const 
Value& v); template ()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); /** * Get specific attribute value by @attrId */ template - const typename details::ccl_api_type_attr_traits::return_type& + const typename detail::ccl_api_type_attr_traits::return_type& get() const; template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; private: - friend class environment; - friend struct ccl_empty_attr; + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; sparse_allreduce_attr( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); ; }; @@ -579,59 +577,82 @@ extern barrier_attr default_barrier_attr; */ template constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - return details::attr_value_tripple(v); + -> detail::attr_value_triple { + return detail::attr_value_triple(v); } template constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - return details::attr_value_tripple(v); + -> detail::attr_value_triple { + return detail::attr_value_triple(v); } template constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - return details::attr_value_tripple(v); + -> detail::attr_value_triple { + return detail::attr_value_triple(v); } template constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - return details::attr_value_tripple(v); + -> detail::attr_value_triple { + return detail::attr_value_triple(v); } template constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - return details::attr_value_tripple(v); + -> detail::attr_value_triple { + return detail::attr_value_triple(v); } template -constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - 
return details::attr_value_tripple(v); +constexpr auto attr_val(value_type v) -> detail::attr_value_triple { + return detail::attr_value_triple(v); } template constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - return details::attr_value_tripple(v); + -> detail::attr_value_triple { + return detail::attr_value_triple(v); } template constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - return details::attr_value_tripple(v); + -> detail::attr_value_triple { + return detail::attr_value_triple(v); } template constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - return details::attr_value_tripple(v); + -> detail::attr_value_triple { + return detail::attr_value_triple(v); } -/* TODO temporary function for UT compilation: would be part of ccl::environment in final*/ +/* TODO temporary function for UT compilation: would be part of detail::environment in final*/ template coll_attribute_type create_coll_attr(attr_value_pair_t&&... avps); + +} // namespace v1 + +using v1::allgatherv_attr; +using v1::allreduce_attr; +using v1::alltoall_attr; +using v1::alltoallv_attr; +using v1::broadcast_attr; +using v1::reduce_attr; +using v1::reduce_scatter_attr; +using v1::sparse_allreduce_attr; +using v1::barrier_attr; +using v1::attr_val; + +using v1::default_allgatherv_attr; +using v1::default_allreduce_attr; +using v1::default_alltoall_attr; +using v1::default_alltoallv_attr; +using v1::default_broadcast_attr; +using v1::default_reduce_attr; +using v1::default_reduce_scatter_attr; +using v1::default_sparse_allreduce_attr; +using v1::default_barrier_attr; + } // namespace ccl diff --git a/include/oneapi/ccl/ccl_coll_attr_ids.hpp b/include/oneapi/ccl/coll_attr_ids.hpp similarity index 89% rename from include/oneapi/ccl/ccl_coll_attr_ids.hpp rename to include/oneapi/ccl/coll_attr_ids.hpp index 149561f64..036ab68bb 100644 --- a/include/oneapi/ccl/ccl_coll_attr_ids.hpp +++ b/include/oneapi/ccl/coll_attr_ids.hpp @@ -21,6 +21,8 
@@ namespace ccl { +namespace v1 { + /** * Common operation attributes id */ @@ -32,9 +34,6 @@ enum class operation_attr_id : int { synchronous, match_id, - prologue_fn, - epilogue_fn, - last_value }; @@ -110,4 +109,18 @@ enum class barrier_attr_id : int { last_value }; + +} // namespace v1 + +using v1::operation_attr_id; +using v1::allgatherv_attr_id; +using v1::allreduce_attr_id; +using v1::alltoall_attr_id; +using v1::alltoallv_attr_id; +using v1::broadcast_attr_id; +using v1::reduce_attr_id; +using v1::reduce_scatter_attr_id; +using v1::sparse_allreduce_attr_id; +using v1::barrier_attr_id; + } // namespace ccl diff --git a/include/oneapi/ccl/ccl_coll_attr_ids_traits.hpp b/include/oneapi/ccl/coll_attr_ids_traits.hpp similarity index 89% rename from include/oneapi/ccl/ccl_coll_attr_ids_traits.hpp rename to include/oneapi/ccl/coll_attr_ids_traits.hpp index 883e66db7..d3d5df484 100644 --- a/include/oneapi/ccl/ccl_coll_attr_ids_traits.hpp +++ b/include/oneapi/ccl/coll_attr_ids_traits.hpp @@ -21,7 +21,7 @@ namespace ccl { -namespace details { +namespace detail { template class function_holder { public: @@ -42,18 +42,6 @@ struct ccl_api_type_attr_traits { using return_type = type; }; -template <> -struct ccl_api_type_attr_traits { - using type = ccl::prologue_fn; - using return_type = function_holder; -}; - -template <> -struct ccl_api_type_attr_traits { - using type = ccl::epilogue_fn; - using return_type = function_holder; -}; - template <> struct ccl_api_type_attr_traits { using type = size_t; @@ -144,5 +132,6 @@ struct ccl_api_type_attr_traits { +public: + using base_t = + ccl_api_base_copyable; + + /** + * Declare PIMPL type + */ + using impl_value_t = typename base_t::impl_value_t; + + /** + * Declare implementation type + */ + using impl_t = typename impl_value_t::element_type; + + comm_attr& operator=(const comm_attr& src); + comm_attr& operator=(comm_attr&& src); + comm_attr(comm_attr&& src); + comm_attr(const comm_attr& src); + comm_attr(ccl_empty_attr); + 
~comm_attr() noexcept; + + /** + * Set specific value for selft attribute by @attrId. + * Previous attibute value would be returned + */ + template ()>::type*/> + Value set(const Value& v); + + /** + * Get specific attribute value by @attrId + */ + template + const typename detail::ccl_api_type_attr_traits::type& get() const; + + template + bool is_valid() const noexcept; + +private: + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; + + comm_attr(const typename detail::ccl_api_type_attr_traits::return_type& + version); +}; +extern comm_attr default_comm_attr; + +template +constexpr auto attr_val(value_type v) -> detail::attr_value_triple { + return detail::attr_value_triple(v); +} + +} // namespace v1 + +using v1::comm_attr; +using v1::default_comm_attr; +using v1::attr_val; + +} // namespace ccl diff --git a/include/oneapi/ccl/comm_attr_ids.hpp b/include/oneapi/ccl/comm_attr_ids.hpp new file mode 100644 index 000000000..3b926f3fa --- /dev/null +++ b/include/oneapi/ccl/comm_attr_ids.hpp @@ -0,0 +1,36 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. 
Please include 'ccl.hpp'" +#endif + +namespace ccl { + +namespace v1 { + +enum class comm_attr_id : int { + version, + + last_value +}; + +} // namespace v1 + +using v1::comm_attr_id; + +} // namespace ccl diff --git a/include/oneapi/ccl/comm_attr_ids_traits.hpp b/include/oneapi/ccl/comm_attr_ids_traits.hpp new file mode 100644 index 000000000..9fec8ab62 --- /dev/null +++ b/include/oneapi/ccl/comm_attr_ids_traits.hpp @@ -0,0 +1,34 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. Please include 'ccl.hpp'" +#endif + +namespace ccl { + +namespace detail { + +template <> +struct ccl_api_type_attr_traits { + using type = ccl::library_version; + using return_type = type; +}; + +} // namespace detail + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_comm_split_attr.hpp b/include/oneapi/ccl/comm_split_attr.hpp similarity index 78% rename from include/oneapi/ccl/ccl_comm_split_attr.hpp rename to include/oneapi/ccl/comm_split_attr.hpp index b84236635..a7c50b66c 100644 --- a/include/oneapi/ccl/ccl_comm_split_attr.hpp +++ b/include/oneapi/ccl/comm_split_attr.hpp @@ -13,74 +13,96 @@ See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - -#ifndef CCL_PRODUCT_FULL -#error "Do not include this file directly. 
Please include 'ccl.hpp'" -#endif - -namespace ccl { - -class ccl_comm_split_attr_impl; -struct ccl_empty_attr; - -/** - * Device attributes - */ -class comm_split_attr : public ccl_api_base_copyable { -public: - using base_t = ccl_api_base_copyable; - - /** - * Declare PIMPL type - */ - using impl_value_t = typename base_t::impl_value_t; - - /** - * Declare implementation type - */ - using impl_t = typename impl_value_t::element_type; - - comm_split_attr& operator=(const comm_split_attr& src); - comm_split_attr& operator=(comm_split_attr&& src); - comm_split_attr(comm_split_attr&& src); - comm_split_attr(const comm_split_attr& src); - comm_split_attr(ccl_empty_attr); - ~comm_split_attr() noexcept; - - /** - * Set specific value for selft attribute by @attrId. - * Previous attibute value would be returned - */ - template ()>::type*/> - Value set(const Value& v); - - /** - * Get specific attribute value by @attrId - */ - template - const typename details::ccl_api_type_attr_traits::type& get() const; - - template - bool is_valid() const noexcept; - -private: - friend class environment; - comm_split_attr( - const typename details::ccl_api_type_attr_traits::type& - version); -}; - -template -constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - return details::attr_value_tripple(v); -} -} // namespace ccl +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. 
Please include 'ccl.hpp'" +#endif + +namespace ccl { +namespace detail { +class environment; +} + +class ccl_comm_split_attr_impl; + +namespace v1 { + +struct ccl_empty_attr; + +/** + * Device attributes + */ +class comm_split_attr : public ccl_api_base_copyable { +public: + using base_t = ccl_api_base_copyable; + + /** + * Declare PIMPL type + */ + using impl_value_t = typename base_t::impl_value_t; + + /** + * Declare implementation type + */ + using impl_t = typename impl_value_t::element_type; + + comm_split_attr& operator=(const comm_split_attr& src); + comm_split_attr& operator=(comm_split_attr&& src); + comm_split_attr(comm_split_attr&& src); + comm_split_attr(const comm_split_attr& src); + comm_split_attr(ccl_empty_attr); + ~comm_split_attr() noexcept; + + /** + * Set specific value for selft attribute by @attrId. + * Previous attibute value would be returned + */ + template ()>::type*/> + Value set(const Value& v); + + /** + * Get specific attribute value by @attrId + */ + template + const typename detail::ccl_api_type_attr_traits::type& get() const; + + template + bool is_valid() const noexcept; + +private: + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; + comm_split_attr( + const typename detail::ccl_api_type_attr_traits::type& + version); +}; + +/** + * Declare extern empty attributes + */ +extern comm_split_attr default_comm_split_attr; + +/** + * Fabric helpers + */ +template +constexpr auto attr_val(value_type v) + -> detail::attr_value_triple { + return detail::attr_value_triple(v); +} + +} // namespace v1 + +using v1::comm_split_attr; +using v1::default_comm_split_attr; +using v1::attr_val; + +} // namespace ccl diff --git a/include/oneapi/ccl/comm_split_attr_ids.hpp b/include/oneapi/ccl/comm_split_attr_ids.hpp new file mode 100644 index 000000000..934b83630 --- /dev/null +++ b/include/oneapi/ccl/comm_split_attr_ids.hpp @@ -0,0 +1,46 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, 
Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. Please include 'ccl.hpp'" +#endif + +namespace ccl { + +namespace v1 { + +enum class comm_split_attr_id : int { + version, + + color, + group, + + last_value +}; + +enum class split_group : int { + cluster, + + last_value +}; + +} // namespace v1 + +using v1::comm_split_attr_id; +using v1::split_group; + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_comm_split_attr_ids_traits.hpp b/include/oneapi/ccl/comm_split_attr_ids_traits.hpp similarity index 91% rename from include/oneapi/ccl/ccl_comm_split_attr_ids_traits.hpp rename to include/oneapi/ccl/comm_split_attr_ids_traits.hpp index ac4c630e9..3603aeb27 100644 --- a/include/oneapi/ccl/ccl_comm_split_attr_ids_traits.hpp +++ b/include/oneapi/ccl/comm_split_attr_ids_traits.hpp @@ -13,31 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - -#ifndef CCL_PRODUCT_FULL -#error "Do not include this file directly. 
Please include 'ccl.hpp'" -#endif - -namespace ccl { - -namespace details { - -template <> -struct ccl_api_type_attr_traits { - using type = ccl::library_version; -}; - -template <> -struct ccl_api_type_attr_traits { - using type = int; -}; - -template <> -struct ccl_api_type_attr_traits { - using type = group_split_type; -}; - -} // namespace details - -} // namespace ccl +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. Please include 'ccl.hpp'" +#endif + +namespace ccl { + +namespace detail { + +template <> +struct ccl_api_type_attr_traits { + using type = ccl::library_version; +}; + +template <> +struct ccl_api_type_attr_traits { + using type = int; +}; + +template <> +struct ccl_api_type_attr_traits { + using type = split_group; +}; + +} // namespace detail + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_communicator.hpp b/include/oneapi/ccl/communicator.hpp similarity index 63% rename from include/oneapi/ccl/ccl_communicator.hpp rename to include/oneapi/ccl/communicator.hpp index 1c1f1af47..d6ee6e477 100644 --- a/include/oneapi/ccl/ccl_communicator.hpp +++ b/include/oneapi/ccl/communicator.hpp @@ -20,20 +20,32 @@ #endif namespace ccl { -class event; -class kvs_interface; -using rank_t = size_t; +namespace detail { +class environment; +} struct communicator_interface; + +template +struct comm_impl_dispatch_selector; + +class comm_group; + +namespace v1 { +class context; +class device; +class kvs_interface; +struct impl_dispatch; + /** - * A device communicator that permits device communication operations + * A communicator that permits communication operations * Has no defined public constructor. - * Use ccl::environment::create_device_communicator for communicator objects creation. + * Use ccl::create_communicator for communicator objects creation. 
*/ class communicator final : public ccl_api_base_movable { + direct_access_policy, + communicator_interface, + std::shared_ptr> { public: using base_t = ccl_api_base_movable stream create_stream(attr_value_pair_t&&... avps) { @@ -98,43 +108,49 @@ class communicator final : public ccl_api_base_movable - friend struct comm_impl_dispatch_selector; + template + friend struct ccl::comm_impl_dispatch_selector; communicator(impl_value_t&& impl); // factory methods template static vector_class create_communicators( - size_t comm_size, + int comm_size, const vector_class& local_devices, - ContextType& context, + const ContextType& context, shared_ptr_class kvs); template static vector_class create_communicators( - size_t comm_size, - const vector_class>& local_rank_device_map, - ContextType& context, + int comm_size, + const vector_class>& local_rank_device_map, + const ContextType& context, shared_ptr_class kvs); template static vector_class create_communicators( - size_t comm_size, - const map_class& local_rank_device_map, - ContextType& context, + int comm_size, + const map_class& local_rank_device_map, + const ContextType& context, shared_ptr_class kvs); - static communicator create_communicator(); - static communicator create_communicator(size_t size, - shared_ptr_class kvs); - static communicator create_communicator(size_t size, - size_t rank, - shared_ptr_class kvs); + static communicator create_communicator(const comm_attr& attr); + static communicator create_communicator(int size, + shared_ptr_class kvs, + const comm_attr& attr); + static communicator create_communicator(int size, + int rank, + shared_ptr_class kvs, + const comm_attr& attr); }; +} // namespace v1 + +using v1::communicator; + } // namespace ccl diff --git a/include/oneapi/ccl/ccl_config.h.in b/include/oneapi/ccl/config.h.in similarity index 81% rename from include/oneapi/ccl/ccl_config.h.in rename to include/oneapi/ccl/config.h.in index 8ca113864..82444621d 100644 --- 
a/include/oneapi/ccl/ccl_config.h.in +++ b/include/oneapi/ccl/config.h.in @@ -22,10 +22,6 @@ #cmakedefine CCL_PRODUCT_BUILD_DATE "@CCL_PRODUCT_BUILD_DATE@" #cmakedefine CCL_PRODUCT_FULL "@CCL_PRODUCT_FULL@" - -/* Configuration settings for multi GPU extension support*/ -#cmakedefine MULTI_GPU_SUPPORT - /* Auto-generated configuration settings for SYCL support */ #cmakedefine CCL_ENABLE_SYCL @@ -33,6 +29,5 @@ @CCL_ENABLE_SYCL_CHECK_CONTRACT@ #endif -#define CCL_ENABLE_SYCL_V @CCL_ENABLE_SYCL_V@ -#define CCL_ENABLE_SYCL_TRUE 1 -#define CCL_ENABLE_SYCL_FALSE 0 +/* Auto-generated configuration settings for multi GPU support*/ +#cmakedefine MULTI_GPU_SUPPORT diff --git a/include/oneapi/ccl/ccl_context.hpp b/include/oneapi/ccl/context.hpp similarity index 56% rename from include/oneapi/ccl/ccl_context.hpp rename to include/oneapi/ccl/context.hpp index 0edd611bb..45ab52d19 100644 --- a/include/oneapi/ccl/ccl_context.hpp +++ b/include/oneapi/ccl/context.hpp @@ -21,10 +21,16 @@ class ccl_context_impl; namespace ccl { +namespace detail { +class environment; +} + +namespace v1 { +class communicator; /** * A context object is an abstraction over CPU/GPU context - * Has no defined public constructor. Use ccl::environment::create_context + * Has no defined public constructor. 
Use ccl::create_context * for context objects creation */ /** @@ -47,30 +53,35 @@ class context : public ccl_api_base_copyable::return_type; + using native_t = + typename detail::ccl_api_type_attr_traits::return_type; context(context&& src); context(const context& src); context& operator=(const context& src); context& operator=(context&& src); ~context(); + bool operator==(const context& rhs) const noexcept; + bool operator!=(const context& rhs) const noexcept; + bool operator<(const context& rhs) const noexcept; + /** * Get specific attribute value by @attrId */ template - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; /** * Get native context object */ - native_t& get_native(); - const native_t& get_native() const; + native_t& get_native(); + const native_t& get_native() const; + private: - friend class environment; - friend class communicator; - friend class device_context_communicator; + friend class ccl::detail::environment; + friend class ccl::v1::communicator; context(impl_value_t&& impl); /** @@ -79,27 +90,33 @@ class context : public ccl_api_base_copyable()>::type*/> - typename ccl::details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); void build_from_params(); - context(const typename details::ccl_api_type_attr_traits::type& version); + context( + const typename detail::ccl_api_type_attr_traits::type& version); /** * Factory methods */ - template ()>::type> - static context create_context(device_context_type&& native_device_context); + template ()>::type> + static context create_context(context_type&& native_context); - template - static context create_context_from_attr(device_context_handle_type& native_device_context_handle, - attr_value_pair_t&&... 
avps); + template + static context create_context_from_attr(context_handle_type& native_context_handle, + attr_value_pair_t&&... avps); }; template -constexpr auto attr_val(value_type v) -> details::attr_value_tripple { - return details::attr_value_tripple(v); +constexpr auto attr_val(value_type v) -> detail::attr_value_triple { + return detail::attr_value_triple(v); } +} // namespace v1 + +using v1::context; +using v1::attr_val; + } // namespace ccl diff --git a/include/oneapi/ccl/ccl_context_attr_ids.hpp b/include/oneapi/ccl/context_attr_ids.hpp similarity index 93% rename from include/oneapi/ccl/ccl_context_attr_ids.hpp rename to include/oneapi/ccl/context_attr_ids.hpp index db24326e1..7facff3e7 100644 --- a/include/oneapi/ccl/ccl_context_attr_ids.hpp +++ b/include/oneapi/ccl/context_attr_ids.hpp @@ -20,6 +20,9 @@ #endif namespace ccl { + +namespace v1 { + /** * Context attribute ids */ @@ -31,4 +34,8 @@ enum class context_attr_id : int { last_value }; +} // namespace v1 + +using v1::context_attr_id; + } // namespace ccl diff --git a/include/oneapi/ccl/ccl_context_attr_ids_traits.hpp b/include/oneapi/ccl/context_attr_ids_traits.hpp similarity index 90% rename from include/oneapi/ccl/ccl_context_attr_ids_traits.hpp rename to include/oneapi/ccl/context_attr_ids_traits.hpp index ccae970af..b85c2ef8a 100644 --- a/include/oneapi/ccl/ccl_context_attr_ids_traits.hpp +++ b/include/oneapi/ccl/context_attr_ids_traits.hpp @@ -20,7 +20,8 @@ #endif namespace ccl { -namespace details { + +namespace detail { /** * Traits for context attributes specializations @@ -39,8 +40,10 @@ struct ccl_api_type_attr_traits { template <> struct ccl_api_type_attr_traits { - using type = typename unified_device_context_type::ccl_native_t; + using type = typename unified_context_type::ccl_native_t; using return_type = type; }; -} -} + +} // namespace detail + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_datatype_attr.hpp b/include/oneapi/ccl/datatype_attr.hpp similarity index 78% 
rename from include/oneapi/ccl/ccl_datatype_attr.hpp rename to include/oneapi/ccl/datatype_attr.hpp index da42d8cc1..4ed1863b2 100644 --- a/include/oneapi/ccl/ccl_datatype_attr.hpp +++ b/include/oneapi/ccl/datatype_attr.hpp @@ -13,67 +13,77 @@ See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - -#ifndef CCL_PRODUCT_FULL -#error "Do not include this file directly. Please include 'ccl.hpp'" -#endif - -namespace ccl { - -class ccl_datatype_attr_impl; - -class datatype_attr : public ccl_api_base_copyable { -public: - using base_t = - ccl_api_base_copyable; - - /** - * Declare PIMPL type - */ - using impl_value_t = typename base_t::impl_value_t; - - /** - * Declare implementation type - */ - using impl_t = typename impl_value_t::element_type; - - datatype_attr& operator=(const datatype_attr& src); - datatype_attr& operator=(datatype_attr&& src); - datatype_attr(datatype_attr&& src); - datatype_attr(const datatype_attr& src); - ~datatype_attr() noexcept; - - /** - * Set specific value for selft attribute by @attrId. - * Previous attibute value would be returned - */ - template ()>::return_type*/> - Value set(const Value& v); - - /** - * Get specific attribute value by @attrId - */ - template - const typename details::ccl_api_type_attr_traits::return_type& get() - const; - -private: - friend class environment; - datatype_attr( - const typename details::ccl_api_type_attr_traits::return_type& - version); -}; - -template -constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - return details::attr_value_tripple(v); -} - -} // namespace ccl +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. 
Please include 'ccl.hpp'" +#endif + +namespace ccl { +namespace detail { +class environment; +} + +class ccl_datatype_attr_impl; + +namespace v1 { + +class datatype_attr : public ccl_api_base_copyable { +public: + using base_t = + ccl_api_base_copyable; + + /** + * Declare PIMPL type + */ + using impl_value_t = typename base_t::impl_value_t; + + /** + * Declare implementation type + */ + using impl_t = typename impl_value_t::element_type; + + datatype_attr& operator=(const datatype_attr& src); + datatype_attr& operator=(datatype_attr&& src); + datatype_attr(datatype_attr&& src); + datatype_attr(const datatype_attr& src); + ~datatype_attr() noexcept; + + /** + * Set specific value for selft attribute by @attrId. + * Previous attibute value would be returned + */ + template ()>::return_type*/> + Value set(const Value& v); + + /** + * Get specific attribute value by @attrId + */ + template + const typename detail::ccl_api_type_attr_traits::return_type& get() + const; + +private: + friend class ccl::detail::environment; + datatype_attr( + const typename detail::ccl_api_type_attr_traits::return_type& + version); +}; + +template +constexpr auto attr_val(value_type v) + -> detail::attr_value_triple { + return detail::attr_value_triple(v); +} + +} // namespace v1 + +using v1::datatype_attr; +using v1::attr_val; + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_datatype_attr_ids.hpp b/include/oneapi/ccl/datatype_attr_ids.hpp similarity index 90% rename from include/oneapi/ccl/ccl_datatype_attr_ids.hpp rename to include/oneapi/ccl/datatype_attr_ids.hpp index 48a33aa85..3f07a8a81 100644 --- a/include/oneapi/ccl/ccl_datatype_attr_ids.hpp +++ b/include/oneapi/ccl/datatype_attr_ids.hpp @@ -13,20 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - -#ifndef CCL_PRODUCT_FULL -#error "Do not include this file directly. 
Please include 'ccl.hpp'" -#endif - -namespace ccl { - -enum class datatype_attr_id : int { - version, - - size, - - last_value -}; - -} +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. Please include 'ccl.hpp'" +#endif + +namespace ccl { + +namespace v1 { + +enum class datatype_attr_id : int { + version, + + size, + + last_value +}; + +} // namespace v1 + +using v1::datatype_attr_id; + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_datatype_attr_ids_traits.hpp b/include/oneapi/ccl/datatype_attr_ids_traits.hpp similarity index 93% rename from include/oneapi/ccl/ccl_datatype_attr_ids_traits.hpp rename to include/oneapi/ccl/datatype_attr_ids_traits.hpp index 9f2c2ed18..ba67cebb3 100644 --- a/include/oneapi/ccl/ccl_datatype_attr_ids_traits.hpp +++ b/include/oneapi/ccl/datatype_attr_ids_traits.hpp @@ -13,27 +13,28 @@ See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - -#ifndef CCL_PRODUCT_FULL -#error "Do not include this file directly. Please include 'ccl.hpp'" -#endif - -namespace ccl { -namespace details { - -template <> -struct ccl_api_type_attr_traits { - using type = ccl::library_version; - using return_type = type; -}; - -template <> -struct ccl_api_type_attr_traits { - using type = size_t; - using return_type = type; -}; - -} // namespace details - -} // namespace ccl +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. 
Please include 'ccl.hpp'" +#endif + +namespace ccl { + +namespace detail { + +template <> +struct ccl_api_type_attr_traits { + using type = ccl::library_version; + using return_type = type; +}; + +template <> +struct ccl_api_type_attr_traits { + using type = size_t; + using return_type = type; +}; + +} // namespace detail + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_device.hpp b/include/oneapi/ccl/device.hpp similarity index 60% rename from include/oneapi/ccl/ccl_device.hpp rename to include/oneapi/ccl/device.hpp index c38fa5470..2f18617ec 100644 --- a/include/oneapi/ccl/ccl_device.hpp +++ b/include/oneapi/ccl/device.hpp @@ -21,10 +21,16 @@ class ccl_device_impl; namespace ccl { +namespace detail { +class environment; +} + +namespace v1 { +class communicator; /** * A device object is an abstraction over CPU/GPU device - * Has no defined public constructor. Use ccl::environment::create_device + * Has no defined public constructor. Use ccl::create_device * for device objects creation */ /** @@ -47,8 +53,9 @@ class device : public ccl_api_base_copyable::return_type; + using native_t = + typename detail::ccl_api_type_attr_traits::return_type; device(device&& src); device(const device& src); @@ -56,21 +63,26 @@ class device : public ccl_api_base_copyable - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; /** * Get native device object */ - native_t& get_native(); - const native_t& get_native() const; + native_t& get_native(); + const native_t& get_native() const; + private: - friend class environment; - friend class communicator; + friend class ccl::detail::environment; + friend class ccl::v1::communicator; device(impl_value_t&& impl); /** @@ -79,10 +91,10 @@ class device : public ccl_api_base_copyable()>::type*/> - typename ccl::details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const 
Value& v); void build_from_params(); - device(const typename details::ccl_api_type_attr_traits::type& version); /** @@ -94,21 +106,28 @@ class device : public ccl_api_base_copyable static device create_device_from_attr(device_handle_type& native_device_handle, - attr_value_pair_t&&... avps); + attr_value_pair_t&&... avps); }; template -constexpr auto attr_val(value_type v) -> details::attr_value_tripple { - return details::attr_value_tripple(v); +constexpr auto attr_val(value_type v) -> detail::attr_value_triple { + return detail::attr_value_triple(v); } -template -using rank_device_pair_t = ccl::pair_class::type>::type>; +template +using rank_device_pair_t = ccl::pair_class< + size_t, + typename std::remove_reference::type>::type>; template -constexpr auto attr_val(size_t rank, device_value_type&& v) - -> rank_device_pair_t{ - return rank_device_pair_t{rank, std::forward(v)}; +constexpr auto attr_val(int rank, device_value_type&& v) -> rank_device_pair_t { + return rank_device_pair_t{ rank, std::forward(v) }; } +} // namespace v1 + +using v1::device; +using v1::attr_val; +using v1::rank_device_pair_t; + } // namespace ccl diff --git a/include/oneapi/ccl/ccl_device_attr_ids.hpp b/include/oneapi/ccl/device_attr_ids.hpp similarity index 93% rename from include/oneapi/ccl/ccl_device_attr_ids.hpp rename to include/oneapi/ccl/device_attr_ids.hpp index cb5644d9e..a3a4866dd 100644 --- a/include/oneapi/ccl/ccl_device_attr_ids.hpp +++ b/include/oneapi/ccl/device_attr_ids.hpp @@ -20,6 +20,9 @@ #endif namespace ccl { + +namespace v1 { + /** * Device attribute ids */ @@ -31,4 +34,8 @@ enum class device_attr_id : int { last_value }; +} // namespace v1 + +using v1::device_attr_id; + } // namespace ccl diff --git a/include/oneapi/ccl/ccl_device_attr_ids_traits.hpp b/include/oneapi/ccl/device_attr_ids_traits.hpp similarity index 95% rename from include/oneapi/ccl/ccl_device_attr_ids_traits.hpp rename to include/oneapi/ccl/device_attr_ids_traits.hpp index c44af0c41..3e0ba6ed1 
100644 --- a/include/oneapi/ccl/ccl_device_attr_ids_traits.hpp +++ b/include/oneapi/ccl/device_attr_ids_traits.hpp @@ -20,7 +20,8 @@ #endif namespace ccl { -namespace details { + +namespace detail { /** * Traits for device attributes specializations @@ -42,5 +43,7 @@ struct ccl_api_type_attr_traits { using type = typename unified_device_type::ccl_native_t; using return_type = type; }; -} -} + +} // namespace detail + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_device_type_traits.hpp b/include/oneapi/ccl/device_type_traits.hpp similarity index 80% rename from include/oneapi/ccl/ccl_device_type_traits.hpp rename to include/oneapi/ccl/device_type_traits.hpp index 08bbad9f7..90ce9d296 100644 --- a/include/oneapi/ccl/ccl_device_type_traits.hpp +++ b/include/oneapi/ccl/device_type_traits.hpp @@ -19,11 +19,13 @@ #error "Do not include this file directly. Please include 'ccl_type_traits.hpp'" #endif -#include "oneapi/ccl/native_device_api/export_api.hpp" +#include "oneapi/ccl/native_device_api/export_api.hpp" namespace ccl { -#define SUPPORTED_KERNEL_NATIVE_DATA_TYPES char, int, float, ccl::bf16, double, int64_t, uint64_t +#define SUPPORTED_KERNEL_NATIVE_DATA_TYPES \ + int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t, float, double, \ + ccl::bfloat16 template constexpr bool is_stream_supported() { @@ -40,13 +42,13 @@ constexpr bool is_event_supported() { template constexpr bool is_device_supported() { return api_type_info::type>::type>::type>::is_supported(); + typename std::remove_reference::type>::type>::type>::is_supported(); } template constexpr bool is_context_supported() { return api_type_info::type>::type>::type>::is_supported(); + typename std::remove_reference::type>::type>::type>::is_supported(); } /** @@ -54,7 +56,7 @@ constexpr bool is_context_supported() { */ API_CLASS_TYPE_INFO(empty_t); API_CLASS_TYPE_INFO(typename unified_device_type::ccl_native_t) -API_CLASS_TYPE_INFO(typename unified_device_context_type::ccl_native_t); 
+API_CLASS_TYPE_INFO(typename unified_context_type::ccl_native_t); API_CLASS_TYPE_INFO(typename unified_stream_type::ccl_native_t); API_CLASS_TYPE_INFO(typename unified_event_type::ccl_native_t); diff --git a/include/oneapi/ccl/ccl_device_types.hpp b/include/oneapi/ccl/device_types.hpp similarity index 78% rename from include/oneapi/ccl/ccl_device_types.hpp rename to include/oneapi/ccl/device_types.hpp index 4a6bcb41e..09ad9b0c9 100644 --- a/include/oneapi/ccl/ccl_device_types.hpp +++ b/include/oneapi/ccl/device_types.hpp @@ -24,22 +24,6 @@ namespace ccl { * Push the following code into something similar with 'ccl_device_types.hpp' */ -/** Device topology group. */ -typedef enum { - device_group = 0, - thread_group = 1, - process_group = 2, - - ccl_topology_group_last_value -} ccl_topology_group_t; - -enum device_topology_type { undetermined = -1, ring, a2a, last_class_value }; - -// TODO: tmp mapping -#define ring_algo_class device_topology_type::ring -#define a2a_algo_class device_topology_type::a2a -#define ccl_topology_class_last_value device_topology_type::last_class_value - using process_id = size_t; using host_id = std::string; @@ -52,18 +36,18 @@ using cluster_aggregated_device_mask_t = std::map::max(); //TODO + //TODO implement class instead using device_index_type = std::tuple; enum device_index_enum { driver_index_id, device_index_id, subdevice_index_id }; std::string to_string(const device_index_type& device_id); device_index_type from_string(const std::string& device_id_str); -using device_indices_t = std::multiset; -using process_device_indices_t = std::map; -using cluster_device_indices_t = std::map; +using device_indices_type = std::multiset; +using process_device_indices_type = std::map; +using cluster_device_indices_type = std::map; -struct empty_t{ -}; +struct empty_t {}; template struct backend_info {}; @@ -71,7 +55,7 @@ template struct generic_device_type {}; template -struct generic_device_context_type {}; +struct generic_context_type {}; 
template struct generic_platform_type {}; diff --git a/include/oneapi/ccl/environment.hpp b/include/oneapi/ccl/environment.hpp new file mode 100644 index 000000000..bf5c4193e --- /dev/null +++ b/include/oneapi/ccl/environment.hpp @@ -0,0 +1,301 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. Please include 'ccl.hpp'" +#endif + +#include +#include +#include +#include + +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/type_traits.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" +#include "oneapi/ccl/coll_attr.hpp" + +#include "oneapi/ccl/comm_attr_ids.hpp" +#include "oneapi/ccl/comm_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_attr.hpp" + +#include "oneapi/ccl/comm_split_attr_ids.hpp" +#include "oneapi/ccl/comm_split_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_split_attr.hpp" + +#include "oneapi/ccl/context_attr_ids.hpp" +#include "oneapi/ccl/context_attr_ids_traits.hpp" +#include "oneapi/ccl/context.hpp" + +#include "oneapi/ccl/datatype_attr_ids.hpp" +#include "oneapi/ccl/datatype_attr_ids_traits.hpp" +#include "oneapi/ccl/datatype_attr.hpp" + +#include "oneapi/ccl/device_attr_ids.hpp" +#include "oneapi/ccl/device_attr_ids_traits.hpp" +#include "oneapi/ccl/device.hpp" + +#include "oneapi/ccl/init_attr_ids.hpp" +#include 
"oneapi/ccl/init_attr_ids_traits.hpp" +#include "oneapi/ccl/init_attr.hpp" + +#include "oneapi/ccl/kvs_attr_ids.hpp" +#include "oneapi/ccl/kvs_attr_ids_traits.hpp" +#include "oneapi/ccl/kvs_attr.hpp" + +#include "oneapi/ccl/kvs.hpp" + +#include "oneapi/ccl/event.hpp" + +#include "oneapi/ccl/stream_attr_ids.hpp" +#include "oneapi/ccl/stream_attr_ids_traits.hpp" +#include "oneapi/ccl/stream.hpp" + +#include "oneapi/ccl/communicator.hpp" + +#include "oneapi/ccl/exception.hpp" + +namespace ccl { + +namespace detail { + +/** + * CCL environment singleton + */ +class environment { +public: + ~environment(); + + /** + * Retrieves the unique environment object + * and makes the first-time initialization of CCL library + */ + static environment& instance(); + + static ccl::library_version get_library_version(); + + template + static init_attr create_init_attr(attr_value_pair_t&&... avps) { + auto init_create_attr = create_postponed_api_type(); + int expander[]{ (init_create_attr.template set(avps.val()), + 0)... }; + (void)expander; + return init_create_attr; + } + + template + static coll_attribute_type create_operation_attr(attr_value_pair_t&&... avps) { + auto op_attr = create_postponed_api_type(); + int expander[]{ (op_attr.template set(avps.val()), 0)... }; + (void)expander; + return op_attr; + } + + /******************** DATATYPE ********************/ + + template + static datatype_attr create_datatype_attr(attr_value_pair_t&&... avps) { + static_assert(sizeof...(avps) > 0, "At least one argument must be specified"); + auto attr = create_postponed_api_type(); + int expander[]{ (attr.template set(avps.val()), 0)... }; + (void)expander; + return attr; + } + + ccl::datatype register_datatype(const datatype_attr& attr); + void deregister_datatype(ccl::datatype dtype); + size_t get_datatype_size(ccl::datatype dtype) const; + + /******************** KVS ********************/ + + template + static kvs_attr create_kvs_attr(attr_value_pair_t&&... 
avps) { + auto kvs_create_attr = create_postponed_api_type(); + int expander[]{ (kvs_create_attr.template set(avps.val()), + 0)... }; + (void)expander; + return kvs_create_attr; + } + + shared_ptr_class create_main_kvs(const kvs_attr& attr) const; + shared_ptr_class create_kvs(const kvs::address_type& addr, const kvs_attr& attr) const; + + /******************** DEVICE ********************/ + + device create_device(empty_t empty) const; + + template ()>::type> + device create_device(native_device_type&& native_device) const; + + template + device create_device_from_attr(typename unified_device_type::ccl_native_t dev, + attr_value_pair_t&&... avps) const { + device str = create_postponed_api_type(dev); + int expander[]{ (str.template set(avps.val()), 0)... }; + (void)expander; + str.build_from_params(); + return str; + } + + /******************** CONTEXT ********************/ + + context create_context(empty_t empty) const; + + template < + class native_device_contex_type, + class = typename std::enable_if()>::type> + context create_context(native_device_contex_type&& native_context) const; + + template + context create_context_from_attr(typename unified_context_type::ccl_native_t ctx, + attr_value_pair_t&&... avps) const { + context str = create_postponed_api_type(ctx); + int expander[]{ (str.template set(avps.val()), 0)... 
}; + (void)expander; + str.build_from_params(); + return str; + } + + /******************** EVENT ********************/ + + template ()>::type> + event create_event(event_type& native_event) { + return event::create_from_native(native_event); + } + + template ()>::type> + event create_event(event_handle_type& native_event_handle, event::context_t& context) { + return event::create_from_native(native_event_handle, context); + } + + /******************** STREAM ********************/ + + template ()>::type> + stream create_stream(native_stream_type& native_stream); + + template ()>::type> + stream create_stream(native_stream_type& native_stream, native_context_type& native_ctx); + + template + stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, + attr_value_pair_t&&... avps) { + stream str = create_stream(device); + int expander[]{ (str.template set(avps.val()), 0)... }; + (void)expander; + str.build_from_params(); + return str; + } + + template + stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, + typename unified_context_type::ccl_native_t context, + attr_value_pair_t&&... avps) { + stream str = create_stream(device, context); + int expander[]{ (str.template set(avps.val()), 0)... }; + (void)expander; + str.build_from_params(); + return str; + } + + /******************** COMMUNICATOR ********************/ + +#ifdef CCL_ENABLE_SYCL + communicator create_single_device_communicator(int comm_size, + int rank, + const cl::sycl::device& device, + const cl::sycl::context& context, + shared_ptr_class kvs) const; +#endif + + template + static comm_split_attr create_comm_split_attr(attr_value_pair_t&&... avps) { + auto split_attr = create_postponed_api_type(); + int expander[]{ (split_attr.template set(avps.val()), 0)... }; + (void)expander; + return split_attr; + } + + template + static comm_attr create_comm_attr(attr_value_pair_t&&... 
avps) { + auto comm_create_attr = create_postponed_api_type(); + int expander[]{ (comm_create_attr.template set(avps.val()), + 0)... }; + (void)expander; + return comm_create_attr; + } + + communicator create_communicator(const comm_attr& attr) const; + communicator create_communicator(size_t size, + shared_ptr_class kvs, + const comm_attr& attr) const; + communicator create_communicator(size_t size, + int rank, + shared_ptr_class kvs, + const comm_attr& attr) const; + + template + vector_class create_communicators(int comm_size, + const vector_class& local_devices, + const ContextType& context, + shared_ptr_class kvs, + const comm_attr& attr) const; + + template + vector_class create_communicators( + int comm_size, + const vector_class>& local_rank_device_map, + const ContextType& context, + shared_ptr_class kvs, + const comm_attr& attr) const; + + template + vector_class create_communicators( + int comm_size, + const map_class& local_rank_device_map, + const ContextType& context, + shared_ptr_class kvs, + const comm_attr& attr) const; + + vector_class split_communicators( + const vector_class>& attrs) const; + +private: + environment(); + + template + static ccl_api_type create_postponed_api_type(args_type... 
args) { + auto version = get_library_version(); + return ccl_api_type(std::forward(args)..., version); + } + + stream create_stream(typename unified_device_type::ccl_native_t device); + + stream create_stream(typename unified_device_type::ccl_native_t device, + typename unified_context_type::ccl_native_t context); +}; + +} // namespace detail + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_event.hpp b/include/oneapi/ccl/event.hpp similarity index 88% rename from include/oneapi/ccl/ccl_event.hpp rename to include/oneapi/ccl/event.hpp index ea7339572..61e49311a 100644 --- a/include/oneapi/ccl/ccl_event.hpp +++ b/include/oneapi/ccl/event.hpp @@ -20,9 +20,14 @@ #endif namespace ccl { +namespace detail { +class environment; +} class event_impl; +namespace v1 { + /** * event's interface that allows users to track communication operation progress */ @@ -41,6 +46,8 @@ class event : public ccl_api_base_movable()) {} // }; -class unimplemented : public ccl::exception { +class unimplemented : public exception { public: - unimplemented(const std::string &domain, const std::string &function, + unimplemented(const std::string &domain, + const std::string &function, const std::string &info = "") - : ccl::exception(domain, function, "function is not implemented " + info) {} + : exception(domain, function, "function is not implemented " + info) {} }; -class unsupported : public ccl::exception { +class unsupported : public exception { public: - unsupported(const std::string &domain, const std::string &function, - const std::string &info = "") - : ccl::exception(domain, function, "function is not supported " + info) {} + unsupported(const std::string &domain, + const std::string &function, + const std::string &info = "") + : exception(domain, function, "function is not supported " + info) {} }; -} +} // namespace v1 + +using v1::exception; +using v1::invalid_argument; +using v1::host_bad_alloc; +// using v1::device_bad_alloc; +using v1::unimplemented; +using 
v1::unsupported; + +} // namespace ccl diff --git a/include/oneapi/ccl/init_attr.hpp b/include/oneapi/ccl/init_attr.hpp new file mode 100644 index 000000000..39e801fb4 --- /dev/null +++ b/include/oneapi/ccl/init_attr.hpp @@ -0,0 +1,96 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. Please include 'ccl.hpp'" +#endif + +namespace ccl { +namespace detail { +class environment; +} + +class init_attr_impl; + +namespace v1 { + +struct ccl_empty_attr; + +class init_attr + : public ccl_api_base_copyable { +public: + using base_t = ccl_api_base_copyable; + + /** + * Declare PIMPL type + */ + using impl_value_t = typename base_t::impl_value_t; + + /** + * Declare implementation type + */ + using impl_t = typename impl_value_t::element_type; + + init_attr& operator=(const init_attr& src); + init_attr& operator=(init_attr&& src); + init_attr(init_attr&& src); + init_attr(const init_attr& src); + ~init_attr() noexcept; + + /** + * Set specific value for selft attribute by @attrId. 
+ * Previous attribute value would be returned + */ + template ()>::return_type*/> + Value set(const Value& v); + + /** + * Get specific attribute value by @attrId + */ + template + const typename detail::ccl_api_type_attr_traits::return_type& get() const; + +private: + friend class ccl::detail::environment; + friend struct ccl::ccl_empty_attr; + init_attr(const typename detail::ccl_api_type_attr_traits::return_type& + version); +}; + +/** + * Declare extern empty attributes + */ +extern init_attr default_init_attr; + +/** + * Fabric helpers + */ +template +constexpr auto attr_val(value_type v) -> detail::attr_value_triple { + return detail::attr_value_triple(v); +} + +} // namespace v1 + +using v1::init_attr; +using v1::default_init_attr; +using v1::attr_val; + +} // namespace ccl diff --git a/include/oneapi/ccl/init_attr_ids.hpp b/include/oneapi/ccl/init_attr_ids.hpp new file mode 100644 index 000000000..6196fde4a --- /dev/null +++ b/include/oneapi/ccl/init_attr_ids.hpp @@ -0,0 +1,36 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. 
Please include 'ccl.hpp'" +#endif + +namespace ccl { + +namespace v1 { + +enum class init_attr_id : int { + version, + + last_value +}; + +} // namespace v1 + +using v1::init_attr_id; + +} // namespace ccl diff --git a/include/oneapi/ccl/init_attr_ids_traits.hpp b/include/oneapi/ccl/init_attr_ids_traits.hpp new file mode 100644 index 000000000..5b36bb2e2 --- /dev/null +++ b/include/oneapi/ccl/init_attr_ids_traits.hpp @@ -0,0 +1,34 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. 
Please include 'ccl.hpp'" +#endif + +namespace ccl { + +namespace detail { + +template <> +struct ccl_api_type_attr_traits { + using type = ccl::library_version; + using return_type = type; +}; + +} // namespace detail + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_kvs.hpp b/include/oneapi/ccl/kvs.hpp similarity index 76% rename from include/oneapi/ccl/ccl_kvs.hpp rename to include/oneapi/ccl/kvs.hpp index fb97c44df..45de27595 100644 --- a/include/oneapi/ccl/ccl_kvs.hpp +++ b/include/oneapi/ccl/kvs.hpp @@ -20,18 +20,23 @@ #endif namespace ccl { +namespace detail { +class environment; +} + +class kvs_impl; + +namespace v1 { class CCL_API kvs_interface { public: - virtual ~kvs_interface() = default; - virtual vector_class get(const string_class& key) const = 0; + virtual vector_class get(const string_class& key) = 0; - virtual void set(const string_class& key, const vector_class& data) const = 0; + virtual void set(const string_class& key, const vector_class& data) = 0; }; -class kvs_impl; class CCL_API kvs final : public kvs_interface { public: static constexpr size_t address_max_size = 256; @@ -41,18 +46,24 @@ class CCL_API kvs final : public kvs_interface { address_type get_address() const; - vector_class get(const string_class& key) const override; + vector_class get(const string_class& key) override; - void set(const string_class& key, const vector_class& data) const override; + void set(const string_class& key, const vector_class& data) override; private: - friend class environment; + friend class ccl::detail::environment; - kvs(); - kvs(const address_type& addr); + kvs(const kvs_attr& attr); + kvs(const address_type& addr, const kvs_attr& attr); const kvs_impl& get_impl(); address_type addr; unique_ptr_class pimpl; }; + +} // namespace v1 + +using v1::kvs_interface; +using v1::kvs; + } // namespace ccl diff --git a/include/oneapi/ccl/kvs_attr.hpp b/include/oneapi/ccl/kvs_attr.hpp new file mode 100644 index 000000000..253b13dd0 --- /dev/null +++ 
b/include/oneapi/ccl/kvs_attr.hpp @@ -0,0 +1,98 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. Please include 'ccl.hpp'" +#endif + +namespace ccl { +namespace detail { +class environment; +} + +class ccl_kvs_attr_impl; + +namespace v1 { + +struct ccl_empty_attr; + +/** + * KVS attributes + */ +class kvs_attr + : public ccl_api_base_copyable { +public: + using base_t = ccl_api_base_copyable; + + /** + * Declare PIMPL type + */ + using impl_value_t = typename base_t::impl_value_t; + + /** + * Declare implementation type + */ + using impl_t = typename impl_value_t::element_type; + + kvs_attr& operator=(const kvs_attr& src); + kvs_attr& operator=(kvs_attr&& src); + kvs_attr(kvs_attr&& src); + kvs_attr(const kvs_attr& src); + kvs_attr(ccl_empty_attr); + ~kvs_attr() noexcept; + + /** + * Set specific value for self attribute by @attrId. 
+ * Previous attribute value would be returned + */ + template ()>::type*/> + Value set(const Value& v); + + /** + * Get specific attribute value by @attrId + */ + template + const typename detail::ccl_api_type_attr_traits::type& get() const; + + template + bool is_valid() const noexcept; + +private: + friend class ccl::detail::environment; + friend struct ccl::v1::ccl_empty_attr; + + kvs_attr(const typename detail::ccl_api_type_attr_traits::return_type& + version); +}; + +extern kvs_attr default_kvs_attr; + +template +constexpr auto attr_val(value_type v) -> detail::attr_value_triple { + return detail::attr_value_triple(v); +} + +} // namespace v1 + +using v1::kvs_attr; +using v1::default_kvs_attr; +using v1::attr_val; + +} // namespace ccl diff --git a/include/oneapi/ccl/kvs_attr_ids.hpp b/include/oneapi/ccl/kvs_attr_ids.hpp new file mode 100644 index 000000000..f753a2f76 --- /dev/null +++ b/include/oneapi/ccl/kvs_attr_ids.hpp @@ -0,0 +1,36 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. 
Please include 'ccl.hpp'" +#endif + +namespace ccl { + +namespace v1 { + +enum class kvs_attr_id : int { + version, + + last_value +}; + +} // namespace v1 + +using v1::kvs_attr_id; + +} // namespace ccl diff --git a/include/oneapi/ccl/kvs_attr_ids_traits.hpp b/include/oneapi/ccl/kvs_attr_ids_traits.hpp new file mode 100644 index 000000000..281955994 --- /dev/null +++ b/include/oneapi/ccl/kvs_attr_ids_traits.hpp @@ -0,0 +1,34 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifndef CCL_PRODUCT_FULL +#error "Do not include this file directly. Please include 'ccl.hpp'" +#endif + +namespace ccl { + +namespace detail { + +template <> +struct ccl_api_type_attr_traits { + using type = ccl::library_version; + using return_type = type; +}; + +} // namespace detail + +} // namespace ccl diff --git a/include/oneapi/ccl/lp_types.hpp b/include/oneapi/ccl/lp_types.hpp new file mode 100644 index 000000000..59109e3e6 --- /dev/null +++ b/include/oneapi/ccl/lp_types.hpp @@ -0,0 +1,72 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include + +namespace ccl { + +namespace preview { + +// struct float16 { +// constexpr float16() : data(0) {} +// constexpr float16(uint16_t v) : data(v) {} +// uint16_t data; + +// friend std::ostream& operator<<(std::ostream& out, const float16& v) { +// out << v.data; +// return out; +// } + +// friend bool operator==(const float16& v1, const float16& v2) { +// return (v1.data == v2.data) ? true : false; +// } + +// friend bool operator!=(const float16& v1, const float16& v2) { +// return !(v1 == v2); +// } + +// } __attribute__((packed)); + +} // namespace preview + +namespace v1 { + +struct bfloat16 { + constexpr bfloat16() : data(0) {} + constexpr bfloat16(uint16_t v) : data(v) {} + uint16_t data; + + friend std::ostream& operator<<(std::ostream& out, const bfloat16& v) { + out << v.data; + return out; + } + + friend bool operator==(const bfloat16& v1, const bfloat16& v2) { + return (v1.data == v2.data) ? 
true : false; + } + + friend bool operator!=(const bfloat16& v1, const bfloat16& v2) { + return !(v1 == v2); + } + +} __attribute__((packed)); + +} // namespace v1 + +using v1::bfloat16; + +} // namespace ccl diff --git a/include/oneapi/ccl/native_device_api/empty/context.hpp b/include/oneapi/ccl/native_device_api/empty/context.hpp index 92291cea7..d326b2392 100644 --- a/include/oneapi/ccl/native_device_api/empty/context.hpp +++ b/include/oneapi/ccl/native_device_api/empty/context.hpp @@ -16,6 +16,5 @@ #pragma once namespace native { -struct ccl_context { -}; -} +struct ccl_context {}; +} // namespace native diff --git a/include/oneapi/ccl/native_device_api/empty/device.hpp b/include/oneapi/ccl/native_device_api/empty/device.hpp index 3970fabed..7105be02a 100644 --- a/include/oneapi/ccl/native_device_api/empty/device.hpp +++ b/include/oneapi/ccl/native_device_api/empty/device.hpp @@ -21,4 +21,4 @@ struct ccl_device { using device_event = ccl_device_event; using device_queue = ccl_device_queue; }; -} +} // namespace native diff --git a/include/oneapi/ccl/native_device_api/empty/event.hpp b/include/oneapi/ccl/native_device_api/empty/event.hpp index c0d098acb..1e8edd786 100644 --- a/include/oneapi/ccl/native_device_api/empty/event.hpp +++ b/include/oneapi/ccl/native_device_api/empty/event.hpp @@ -16,6 +16,5 @@ #pragma once namespace native { -struct ccl_device_queue { -}; -} +struct ccl_device_queue {}; +} // namespace native diff --git a/include/oneapi/ccl/native_device_api/empty/export.hpp b/include/oneapi/ccl/native_device_api/empty/export.hpp index bfcad4b3c..9e66d3138 100644 --- a/include/oneapi/ccl/native_device_api/empty/export.hpp +++ b/include/oneapi/ccl/native_device_api/empty/export.hpp @@ -14,7 +14,7 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_types.hpp" +#include "oneapi/ccl/types.hpp" #define CL_BACKEND_TYPE ccl::cl_backend_type::empty_backend @@ -23,16 +23,16 @@ #include "oneapi/ccl/native_device_api/empty/platform.hpp" #include "oneapi/ccl/native_device_api/empty/primitives.hpp" - -namespace ccl -{ +namespace ccl { template <> struct backend_info { CCL_API static constexpr ccl::cl_backend_type type() { - return CL_BACKEND_TYPE; } + return CL_BACKEND_TYPE; + } CCL_API static constexpr const char* name() { - return "CL_BACKEND_UNAVAILABLE"; } + return "CL_BACKEND_UNAVAILABLE"; + } }; template <> @@ -41,20 +41,25 @@ struct generic_device_type { using impl_t = native::ccl_device; using ccl_native_t = std::shared_ptr; - template - generic_device_type(T&& not_used) {(void)not_used;}; + template + generic_device_type(T&& not_used) { + (void)not_used; + }; void get_id() const noexcept; - ccl_native_t get() noexcept; + ccl_native_t& get() noexcept; + const ccl_native_t& get() const noexcept; }; template <> -struct generic_device_context_type { +struct generic_context_type { using handle_t = empty_t; using impl_t = native::ccl_context; using ccl_native_t = std::shared_ptr; - template - generic_device_context_type(T&& not_used) {(void)not_used;}; + template + generic_context_type(T&& not_used) { + (void)not_used; + }; ccl_native_t get() noexcept; const ccl_native_t& get() const noexcept; @@ -92,4 +97,4 @@ struct generic_event_type { ccl_native_t get() noexcept; const ccl_native_t& get() const noexcept; }; -} +} // namespace ccl diff --git a/include/oneapi/ccl/native_device_api/empty/platform.hpp b/include/oneapi/ccl/native_device_api/empty/platform.hpp index 70a27b014..f9fcecbd7 100644 --- a/include/oneapi/ccl/native_device_api/empty/platform.hpp +++ b/include/oneapi/ccl/native_device_api/empty/platform.hpp @@ -16,6 +16,5 @@ #pragma once namespace native { -struct ccl_device_platform { -}; -} +struct ccl_device_platform {}; +} // namespace native diff --git 
a/include/oneapi/ccl/native_device_api/empty/primitives.hpp b/include/oneapi/ccl/native_device_api/empty/primitives.hpp index 769e06a23..8bd66a66d 100644 --- a/include/oneapi/ccl/native_device_api/empty/primitives.hpp +++ b/include/oneapi/ccl/native_device_api/empty/primitives.hpp @@ -16,8 +16,6 @@ #pragma once namespace native { -struct ccl_device_event { -}; -struct ccl_device_queue { -}; -} +struct ccl_device_event {}; +struct ccl_device_queue {}; +} // namespace native diff --git a/include/oneapi/ccl/native_device_api/empty/queue.hpp b/include/oneapi/ccl/native_device_api/empty/queue.hpp index c0d098acb..1e8edd786 100644 --- a/include/oneapi/ccl/native_device_api/empty/queue.hpp +++ b/include/oneapi/ccl/native_device_api/empty/queue.hpp @@ -16,6 +16,5 @@ #pragma once namespace native { -struct ccl_device_queue { -}; -} +struct ccl_device_queue {}; +} // namespace native diff --git a/include/oneapi/ccl/native_device_api/export_api.hpp b/include/oneapi/ccl/native_device_api/export_api.hpp index 3a1a79fd8..17d015351 100644 --- a/include/oneapi/ccl/native_device_api/export_api.hpp +++ b/include/oneapi/ccl/native_device_api/export_api.hpp @@ -14,12 +14,12 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_config.h" +#include "oneapi/ccl/config.h" #ifdef CCL_ENABLE_SYCL - #ifdef MULTI_GPU_SUPPORT - #include "sycl_l0/export.hpp" - /* +#ifdef MULTI_GPU_SUPPORT +#include "sycl_l0/export.hpp" +/* #include "oneapi/ccl/native_device_api/l0/base.hpp" #include "oneapi/ccl/native_device_api/l0/base_impl.hpp" @@ -32,15 +32,15 @@ #include "oneapi/ccl/native_device_api/l0/driver.hpp" #include "oneapi/ccl/native_device_api/l0/platform.hpp" */ - #else - #include "sycl/export.hpp" - #endif #else - #ifdef MULTI_GPU_SUPPORT - #include "l0/export.hpp" - #else - #include "empty/export.hpp" - #endif +#include "sycl/export.hpp" +#endif +#else +#ifdef MULTI_GPU_SUPPORT +#include "l0/export.hpp" +#else +#include "empty/export.hpp" +#endif #endif #ifndef CL_BACKEND_TYPE @@ -49,10 +49,10 @@ namespace ccl { using backend_traits = backend_info; using unified_device_type = generic_device_type; -using unified_device_context_type = generic_device_context_type; +using unified_context_type = generic_context_type; using unified_platform_type = generic_platform_type; using unified_stream_type = generic_stream_type; using unified_event_type = generic_event_type; -} +} // namespace ccl #include "interop_utils.hpp" diff --git a/include/oneapi/ccl/native_device_api/interop_utils.hpp b/include/oneapi/ccl/native_device_api/interop_utils.hpp index ccaae1b13..15587b6b8 100644 --- a/include/oneapi/ccl/native_device_api/interop_utils.hpp +++ b/include/oneapi/ccl/native_device_api/interop_utils.hpp @@ -14,14 +14,14 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_type_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/type_traits.hpp" #ifdef CCL_ENABLE_SYCL #include #endif namespace native { -namespace details { +namespace detail { #ifdef CCL_ENABLE_SYCL size_t get_sycl_device_id(const cl::sycl::device& dev); @@ -31,24 +31,28 @@ std::string usm_to_string(cl::sycl::usm::alloc val); enum usm_support_mode { prohibited = 0, direct, shared, need_conversion, last_value }; std::string to_string(usm_support_mode val); -using assoc_retult = std::tuple; +using assoc_result = std::tuple; enum assoc_result_index { SUPPORT_MODE = 0, POINTER_VALUE, ERROR_CAUSE }; #if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) -assoc_retult check_assoc_device_memory(const void* mem, +assoc_result check_assoc_device_memory(const void* mem, const ccl::unified_device_type::ccl_native_t& device, - const ccl::unified_device_context_type::ccl_native_t& ctx); + const ccl::unified_context_type::ccl_native_t& ctx); + +usm_support_mode check_assoc_device_memory(const std::vector& mems, + const ccl::unified_device_type::ccl_native_t& device, + const ccl::unified_context_type::ccl_native_t& ctx); #endif //defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) -std::string to_string(const assoc_retult& res); +std::string to_string(const assoc_result& res); #if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) template -using multiple_assoc_result = std::array; +using multiple_assoc_result = std::array; template auto check_multiple_assoc_device_memory(const ccl::unified_device_type::ccl_native_t& device, - const ccl::unified_device_context_type::ccl_native_t& ctx, + const ccl::unified_context_type::ccl_native_t& ctx, const mem_type*... mem) -> multiple_assoc_result { multiple_assoc_result ret{ check_assoc_device_memory(mem, device, ctx)... 
}; @@ -64,5 +68,5 @@ std::string to_string(const multiple_assoc_result& res) { return ss.str(); } #endif //defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) -} // namespace details +} // namespace detail } // namespace native diff --git a/include/oneapi/ccl/native_device_api/l0/base.hpp b/include/oneapi/ccl/native_device_api/l0/base.hpp index c7429a056..18e8721fa 100644 --- a/include/oneapi/ccl/native_device_api/l0/base.hpp +++ b/include/oneapi/ccl/native_device_api/l0/base.hpp @@ -23,9 +23,9 @@ #include #ifndef UT -//#include "oneapi/ccl/ccl_types.hpp" -//#include "oneapi/ccl/ccl_type_traits.hpp" -#include "oneapi/ccl/ccl_types.hpp" +//#include "oneapi/ccl/types.hpp" +//#include "oneapi/ccl/type_traits.hpp" +#include "oneapi/ccl/types.hpp" #endif namespace native { diff --git a/include/oneapi/ccl/native_device_api/l0/base_impl.hpp b/include/oneapi/ccl/native_device_api/l0/base_impl.hpp index 19447d44b..61a951030 100644 --- a/include/oneapi/ccl/native_device_api/l0/base_impl.hpp +++ b/include/oneapi/ccl/native_device_api/l0/base_impl.hpp @@ -59,7 +59,7 @@ template cl_base::~cl_base() noexcept { auto lock = owner.lock(); // auto ctx = context.lock(); ctx->get(); - ze_context_handle_t ctxtmp; + ze_context_handle_t ctxtmp = nullptr; if (lock) { lock->on_delete(handle, ctxtmp); } @@ -226,7 +226,7 @@ indexed_storage merge_indexed_values(const IndexedContainer& indexes } template -indexed_storage collect_indexed_data(const ccl::device_indices_t& indexes, +indexed_storage collect_indexed_data(const ccl::device_indices_type& indexes, std::vector& collected_values, value_type_index_extractor functor) { indexed_storage ret; diff --git a/include/oneapi/ccl/native_device_api/l0/context.hpp b/include/oneapi/ccl/native_device_api/l0/context.hpp index ea108184f..c5fb08f28 100644 --- a/include/oneapi/ccl/native_device_api/l0/context.hpp +++ b/include/oneapi/ccl/native_device_api/l0/context.hpp @@ -41,16 +41,29 @@ struct ccl_context : public cl_base get_ptr() { return 
this->shared_from_this(); } - }; -struct ccl_context_holder -{ - std::map>> map_context; +class context_array_t { +public: + using value_type = std::vector>; + using context_array_accessor = detail::unique_accessor; - ze_context_handle_t get() { - return nullptr; - } + context_array_accessor access(); + +private: + std::mutex m; + value_type contexts; }; +struct ccl_context_holder { + ze_context_handle_t get(); + std::shared_ptr emplace(ccl_device_driver* driver, + std::shared_ptr&& ctx); + context_array_t& get_context_storage(ccl_device_driver* driver); + +private: + std::mutex m; + std::map drivers_context; +}; +using ccl_driver_context_ptr = std::shared_ptr; } // namespace native diff --git a/include/oneapi/ccl/native_device_api/l0/device.hpp b/include/oneapi/ccl/native_device_api/l0/device.hpp index 635989847..a6ccafc5b 100644 --- a/include/oneapi/ccl/native_device_api/l0/device.hpp +++ b/include/oneapi/ccl/native_device_api/l0/device.hpp @@ -27,9 +27,9 @@ struct ccl_subdevice; struct ccl_device; struct ccl_context; -details::cross_device_rating property_p2p_rating_calculator(const ccl_device& lhs, - const ccl_device& rhs, - size_t weight); +detail::cross_device_rating property_p2p_rating_calculator(const ccl_device& lhs, + const ccl_device& rhs, + size_t weight); // TODO not thread-safe!!! 
struct ccl_device : public cl_base, @@ -75,11 +75,11 @@ struct ccl_device : public cl_base create( handle_t h, owner_ptr_t&& driver, - const ccl::device_indices_t& indexes = ccl::device_indices_t()); + const ccl::device_indices_type& indexes = ccl::device_indices_type()); std::shared_ptr get_ptr() { return this->shared_from_this(); @@ -89,7 +89,7 @@ struct ccl_device : public cl_baseshared_from_this(); } - context_storage_type get_device_contexts(); + context_storage_type get_contexts(); sub_devices_container_type& get_subdevices(); const sub_devices_container_type& get_subdevices() const; @@ -108,43 +108,59 @@ struct ccl_device : public cl_base device_memory alloc_memory( size_t count, - size_t alignment, std::shared_ptr ctx, + size_t alignment, + std::shared_ptr ctx, const ze_device_mem_alloc_desc_t& mem_descr = get_default_mem_alloc_desc(), const ze_host_mem_alloc_desc_t& host_descr = get_default_host_alloc_desc()) { - return device_memory(reinterpret_cast(device_alloc_memory( - count * sizeof(elem_t), alignment, mem_descr, host_descr, ctx)), - count, - get_ptr(), ctx); + return device_memory( + reinterpret_cast( + device_alloc_memory(count * sizeof(elem_t), alignment, mem_descr, host_descr, ctx)), + count, + get_ptr(), + ctx); } template device_memory_ptr alloc_shared_memory( size_t count, size_t alignment, - const ze_host_mem_alloc_desc_t& host_desc, std::shared_ptr ctx, + const ze_host_mem_alloc_desc_t& host_desc, + std::shared_ptr ctx, const ze_device_mem_alloc_desc_t& mem_descr = get_default_mem_alloc_desc()) { return std::make_shared>( reinterpret_cast(device_alloc_shared_memory( count * sizeof(elem_t), alignment, host_desc, mem_descr, ctx)), count, - get_ptr(), ctx); + get_ptr(), + ctx); } - device_ipc_memory_handle create_ipc_memory_handle(void* device_mem_ptr, std::shared_ptr ctx); - std::shared_ptr create_shared_ipc_memory_handle(void* device_mem_ptr, std::shared_ptr ctx); + device_ipc_memory_handle create_ipc_memory_handle(void* device_mem_ptr, + 
std::shared_ptr ctx); + std::shared_ptr create_shared_ipc_memory_handle( + void* device_mem_ptr, + std::shared_ptr ctx); - device_ipc_memory get_ipc_memory(std::shared_ptr&& handle, std::shared_ptr ctx); + device_ipc_memory get_ipc_memory(std::shared_ptr&& handle, + std::shared_ptr ctx); std::shared_ptr restore_shared_ipc_memory( - std::shared_ptr&& handle, std::shared_ptr ctx); + std::shared_ptr&& handle, + std::shared_ptr ctx); - device_queue create_cmd_queue(std::shared_ptr ctx, + device_queue create_cmd_queue( + std::shared_ptr ctx, const ze_command_queue_desc_t& properties = get_default_queue_desc()); - ze_fence_handle_t create_or_get_fence(const device_queue& queue, std::shared_ptr ctx); - device_queue& get_cmd_queue(const ze_command_queue_desc_t& properties, std::shared_ptr ctx); - device_cmd_list create_cmd_list(std::shared_ptr ctx, + device_queue_fence& get_fence(const device_queue& queue, std::shared_ptr ctx); + device_queue& get_cmd_queue(const ze_command_queue_desc_t& properties, + std::shared_ptr ctx); + device_cmd_list create_cmd_list( + std::shared_ptr ctx, const ze_command_list_desc_t& properties = get_default_list_desc()); - device_cmd_list& get_cmd_list(std::shared_ptr ctx, + device_cmd_list& get_cmd_list( + std::shared_ptr ctx, const ze_command_list_desc_t& properties = get_default_list_desc()); - device_module_ptr create_module(const ze_module_desc_t& descr, size_t hash, std::shared_ptr ctx); + device_module_ptr create_module(const ze_module_desc_t& descr, + size_t hash, + std::shared_ptr ctx); template bool is_own_memory(const device_memory& mem_handle) { @@ -154,7 +170,8 @@ struct ccl_device : public cl_base { }); + handle_t assoc_handle = + get_assoc_device_handle(mem_handle.handle, get_owner(), std::shared_ptr{}); return assoc_handle == handle; } @@ -170,7 +187,7 @@ struct ccl_device : public cl_base void on_delete(elem_t* mem_handle, ze_context_handle_t& ctx) { // TODO: ctx - device_free_memory(static_cast(mem_handle), std::shared_ptr { }); 
+ device_free_memory(static_cast(mem_handle), std::shared_ptr{}); } // serialize/deserialize @@ -185,10 +202,13 @@ struct ccl_device : public cl_base& out, size_t from_pos, size_t expected_size) const; - std::shared_ptr get_default_device_context(); + std::shared_ptr get_default_context(); private: - ccl_device(handle_t h, owner_ptr_t&& parent, std::weak_ptr&& ctx, std::false_type); + ccl_device(handle_t h, + owner_ptr_t&& parent, + std::weak_ptr&& ctx, + std::false_type); void initialize_device_data(); void* device_alloc_memory(size_t bytes_count, size_t alignment, @@ -201,8 +221,9 @@ struct ccl_device : public cl_base ctx); - static handle_t get_assoc_device_handle(const void* ptr, const ccl_device_driver* driver, - std::shared_ptr ctx); + static handle_t get_assoc_device_handle(const void* ptr, + const ccl_device_driver* driver, + std::shared_ptr ctx); void device_free_memory(void* mem_ptr, std::shared_ptr ctx); //TODO shared mutex? @@ -210,6 +231,7 @@ struct ccl_device : public cl_base cmd_queus; std::map queue_fences; + std::mutex list_mutex; std::map cmd_lists; sub_devices_container_type sub_devices; diff --git a/include/oneapi/ccl/native_device_api/l0/driver.hpp b/include/oneapi/ccl/native_device_api/l0/driver.hpp index 5ee420051..60612c510 100644 --- a/include/oneapi/ccl/native_device_api/l0/driver.hpp +++ b/include/oneapi/ccl/native_device_api/l0/driver.hpp @@ -14,6 +14,7 @@ limitations under the License. 
*/ #pragma once + #include #include #include @@ -27,8 +28,9 @@ struct ccl_device_driver; struct ccl_device; struct ccl_context; struct ccl_context_holder; -struct ccl_device_driver : public cl_base, - std::enable_shared_from_this { +struct ccl_device_driver + : public cl_base, + std::enable_shared_from_this { friend std::ostream& operator<<(std::ostream&, const ccl_device_driver&); using base = cl_base; @@ -44,10 +46,13 @@ struct ccl_device_driver : public cl_base; using indexed_driver_handles = indexed_storage; - ccl_device_driver(handle_t h, uint32_t id, owner_ptr_t&& platform, std::weak_ptr&& ctx); + ccl_device_driver(handle_t h, + uint32_t id, + owner_ptr_t&& platform, + std::weak_ptr&& ctx); static indexed_driver_handles get_handles( - const ccl::device_indices_t& requested_driver_indexes = ccl::device_indices_t()); + const ccl::device_indices_type& requested_driver_indexes = ccl::device_indices_type()); static std::shared_ptr create( handle_t h, uint32_t id, @@ -58,7 +63,7 @@ struct ccl_device_driver : public cl_base get_ptr() { return this->shared_from_this(); @@ -74,6 +79,7 @@ struct ccl_device_driver : public cl_base create_context(); + std::shared_ptr create_context_from_handle(ccl_context::handle_t); std::string to_string(const std::string& prefix = std::string()) const; @@ -92,8 +98,8 @@ struct ccl_device_driver : public cl_base struct backend_info { CCL_API static constexpr ccl::cl_backend_type type() { - return CL_BACKEND_TYPE; } + return CL_BACKEND_TYPE; + } CCL_API static constexpr const char* name() { - return "CL_INTEL_L0_BACKEND"; } + return "CL_INTEL_L0_BACKEND"; + } }; template <> @@ -38,25 +39,23 @@ struct generic_device_type { using ccl_native_t = std::shared_ptr; generic_device_type(device_index_type id); + generic_device_type(ccl_native_t dev); device_index_type get_id() const noexcept; - ccl_native_t get() noexcept; + ccl_native_t& get() noexcept; + const ccl_native_t& get() const noexcept; handle_t device; }; -#ifndef ze_context_handle_t 
-#define ze_context_handle_t void* -#endif - template <> -struct generic_device_context_type { +struct generic_context_type { using handle_t = ze_context_handle_t; using impl_t = native::ccl_context; using ccl_native_t = std::shared_ptr; - generic_device_context_type(); - generic_device_context_type(handle_t ctx); - ccl_native_t get() noexcept; + generic_context_type(); + generic_context_type(ccl_native_t ctx); + ccl_native_t& get() noexcept; const ccl_native_t& get() const noexcept; ccl_native_t context; @@ -66,20 +65,20 @@ template <> struct generic_platform_type { using handle_t = native::ccl_device_platform; using impl_t = handle_t; - using ccl_native_t = std::shared_ptr; + using ccl_native_t = impl_t; - ccl_native_t get() noexcept; + ccl_native_t& get() noexcept; const ccl_native_t& get() const noexcept; }; template <> struct generic_stream_type { using handle_t = ze_command_queue_handle_t; - using impl_t = handle_t; - using ccl_native_t = std::shared_ptr; + using impl_t = native::ccl_device::device_queue; + using ccl_native_t = std::shared_ptr; generic_stream_type(handle_t q); - ccl_native_t get() noexcept; + ccl_native_t& get() noexcept; const ccl_native_t& get() const noexcept; ccl_native_t queue; @@ -92,7 +91,7 @@ struct generic_event_type { using ccl_native_t = std::shared_ptr; generic_event_type(handle_t e); - ccl_native_t get() noexcept; + ccl_native_t& get() noexcept; const ccl_native_t& get() const noexcept; ccl_native_t event; @@ -104,4 +103,4 @@ struct generic_event_type { API_CLASS_TYPE_INFO(native::ccl_device::device_queue); //API_CLASS_TYPE_INFO(ze_command_queue_handle_t); API_CLASS_TYPE_INFO(ze_event_handle_t); -} +} // namespace ccl diff --git a/include/oneapi/ccl/native_device_api/l0/platform.hpp b/include/oneapi/ccl/native_device_api/l0/platform.hpp index ba27309dc..4b625376a 100644 --- a/include/oneapi/ccl/native_device_api/l0/platform.hpp +++ b/include/oneapi/ccl/native_device_api/l0/platform.hpp @@ -28,7 +28,7 @@ struct ccl_device_platform 
: std::enable_shared_from_this { using context_storage_type = std::shared_ptr; //void init_drivers(const device_affinity_per_driver& affinities / * = device_affinity_per_driver()* /); - void init_drivers(const ccl::device_indices_t& indices = ccl::device_indices_t()); + void init_drivers(const ccl::device_indices_type& indices = ccl::device_indices_type()); std::shared_ptr get_ptr() { return this->shared_from_this(); @@ -50,12 +50,12 @@ struct ccl_device_platform : std::enable_shared_from_this { void on_delete(ccl_context::handle_t& context, ze_context_handle_t& ctx); static std::shared_ptr create( - const ccl::device_indices_t& indices = ccl::device_indices_t()); + const ccl::device_indices_type& indices = ccl::device_indices_type()); //static std::shared_ptr create(const device_affinity_per_driver& affinities); - details::adjacency_matrix calculate_device_access_metric( - const ccl::device_indices_t& indices = ccl::device_indices_t(), - details::p2p_rating_function func = details::binary_p2p_rating_calculator) const; + detail::adjacency_matrix calculate_device_access_metric( + const ccl::device_indices_type& indices = ccl::device_indices_type(), + detail::p2p_rating_function func = detail::binary_p2p_rating_calculator) const; private: ccl_device_platform(); diff --git a/include/oneapi/ccl/native_device_api/l0/primitives.hpp b/include/oneapi/ccl/native_device_api/l0/primitives.hpp index d55431b71..ddd61523b 100644 --- a/include/oneapi/ccl/native_device_api/l0/primitives.hpp +++ b/include/oneapi/ccl/native_device_api/l0/primitives.hpp @@ -34,6 +34,7 @@ std::string to_string(const ze_device_compute_properties_t& compute_properties, const std::string& prefix = std::string("\n")); std::string to_string(const ze_memory_allocation_properties_t& prop); std::string to_string(const ze_device_p2p_properties_t& properties); +std::string to_string(const ze_device_mem_alloc_desc_t& mem_descr); std::string to_string(const ze_ipc_mem_handle_t& handle); /** @@ -66,13 +67,18 @@ 
template using event = cl_base; template -struct memory/**/ : private cl_base { +struct memory /**/ : private cl_base { using base = cl_base; using base::get_owner; using base::get_ctx; using base::handle; - memory(elem_t* h, size_t count, std::weak_ptr&& owner, std::weak_ptr&& context); + memory(elem_t* h, + size_t count, + std::weak_ptr&& owner, + std::weak_ptr&& context); /** * Memory operations @@ -84,16 +90,18 @@ struct memory/**/ : private cl_base void enqueue_write_sync(const std::array& src); void enqueue_write_sync(const elem_t* src, size_t n); + void enqueue_write_sync(const elem_t* src, int n); // async - queue_fence enqueue_write_async(const std::vector& src, - queue& queue); + queue_fence enqueue_write_async( + const std::vector& src, + queue& queue); template - queue_fence enqueue_write_async(const std::array& src, - queue& queue); - queue_fence enqueue_write_async(const elem_t* src, - size_t n, - queue& queue); + queue_fence enqueue_write_async( + const std::array& src, + queue& queue); + queue_fence + enqueue_write_async(const elem_t* src, size_t n, queue& queue); // sync memory-copy read std::vector enqueue_read_sync(size_t requested_size = 0) const; diff --git a/include/oneapi/ccl/native_device_api/l0/primitives_impl.hpp b/include/oneapi/ccl/native_device_api/l0/primitives_impl.hpp index 7af8fef44..7a3294ecd 100644 --- a/include/oneapi/ccl/native_device_api/l0/primitives_impl.hpp +++ b/include/oneapi/ccl/native_device_api/l0/primitives_impl.hpp @@ -34,7 +34,10 @@ void copy_memory_to_device_sync_unsafe(void* dst, } template -memory::memory(elem_t* h, size_t count, std::weak_ptr&& owner, std::weak_ptr&& context) +memory::memory(elem_t* h, + size_t count, + std::weak_ptr&& owner, + std::weak_ptr&& context) : base(h, std::move(owner), std::move(context)), elem_count(count) {} @@ -117,6 +120,7 @@ void memory::enqueue_write_sync(const std::array& s throw std::runtime_error(std::string(__PRETTY_FUNCTION__) + "\n" + ex.what()); } } + template void 
memory::enqueue_write_sync(const elem_t* src, size_t src_elem_count) { if (!src) { @@ -142,6 +146,12 @@ void memory::enqueue_write_sync(const elem_t* src, size_t src_ } } +template +void memory::enqueue_write_sync(const elem_t* src, int src_elem_count) { + size_t elem_count = src_elem_count; + enqueue_write_sync(src, elem_count); +} + template std::vector memory::enqueue_read_sync( size_t src_elem_count /* = 0*/) const { diff --git a/include/oneapi/ccl/native_device_api/l0/subdevice.hpp b/include/oneapi/ccl/native_device_api/l0/subdevice.hpp index a7c800c91..35bfe08bb 100644 --- a/include/oneapi/ccl/native_device_api/l0/subdevice.hpp +++ b/include/oneapi/ccl/native_device_api/l0/subdevice.hpp @@ -32,13 +32,16 @@ struct ccl_subdevice : public ccl_device { friend std::ostream& operator<<(std::ostream&, const ccl_subdevice& node); - ccl_subdevice(handle_t h, owner_ptr_t&& device, base::owner_ptr_t&& driver, base::context_ptr_t&& ctx); + ccl_subdevice(handle_t h, + owner_ptr_t&& device, + base::owner_ptr_t&& driver, + base::context_ptr_t&& ctx); virtual ~ccl_subdevice(); // factory static indexed_handles get_handles( const ccl_device& device, - const ccl::device_indices_t& requested_indices = ccl::device_indices_t()); + const ccl::device_indices_type& requested_indices = ccl::device_indices_type()); static std::shared_ptr create(handle_t h, owner_ptr_t&& device, base::owner_ptr_t&& driver); @@ -64,7 +67,11 @@ struct ccl_subdevice : public ccl_device { ccl_device_platform& platform); private: - ccl_subdevice(handle_t h, owner_ptr_t&& device, base::owner_ptr_t&& driver, base::context_ptr_t&& ctx, std::false_type); + ccl_subdevice(handle_t h, + owner_ptr_t&& device, + base::owner_ptr_t&& driver, + base::context_ptr_t&& ctx, + std::false_type); void initialize_subdevice_data(); owner_ptr_t parent_device; }; diff --git a/include/oneapi/ccl/native_device_api/l0/utils.hpp b/include/oneapi/ccl/native_device_api/l0/utils.hpp index 4b95ab064..c89e9af82 100644 --- 
a/include/oneapi/ccl/native_device_api/l0/utils.hpp +++ b/include/oneapi/ccl/native_device_api/l0/utils.hpp @@ -16,13 +16,13 @@ #pragma once #include -#include "oneapi/ccl/ccl_types.hpp" -//#include "oneapi/ccl/ccl_type_traits.hpp" +#include "oneapi/ccl/types.hpp" +//#include "oneapi/ccl/type_traits.hpp" namespace native { struct ccl_device; -namespace details { +namespace detail { /* * Boolean matrix represents P2P device capable connectivity 'cross_device_rating' @@ -51,5 +51,19 @@ using p2p_rating_function = std::function; cross_device_rating binary_p2p_rating_calculator(const ccl_device& lhs, const ccl_device& rhs); -} // namespace details + +template +struct unique_accessor { + unique_accessor(Lock& mutex, Resource& storage) : lock(mutex), inner_data(storage) {} + unique_accessor(unique_accessor&& src) = default; + + Resource& get() { + return inner_data; + } + +private: + std::unique_lock lock; + Resource& inner_data; +}; +} // namespace detail } // namespace native diff --git a/include/oneapi/ccl/native_device_api/sycl/export.hpp b/include/oneapi/ccl/native_device_api/sycl/export.hpp index 510e41dbc..17b665e05 100644 --- a/include/oneapi/ccl/native_device_api/sycl/export.hpp +++ b/include/oneapi/ccl/native_device_api/sycl/export.hpp @@ -14,24 +14,25 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_types.hpp" +#include "oneapi/ccl/types.hpp" #define CL_BACKEND_TYPE ccl::cl_backend_type::dpcpp_sycl #include -namespace ccl -{ +namespace ccl { template <> struct backend_info { CCL_API static constexpr ccl::cl_backend_type type() { - return CL_BACKEND_TYPE; } + return CL_BACKEND_TYPE; + } CCL_API static constexpr const char* name() { - return "CL_DPCPP_BACKEND"; } + return "CL_DPCPP_BACKEND"; + } }; template <> struct generic_device_type { - using handle_t = cl_device_id;//cl::sycl::device; + using handle_t = cl_device_id; //cl::sycl::device; using impl_t = cl::sycl::device; using ccl_native_t = impl_t; @@ -40,18 +41,19 @@ struct generic_device_type { generic_device_type(const cl::sycl::device& device); device_index_type get_id() const; ccl_native_t& get() noexcept; + const ccl_native_t& get() const noexcept; cl::sycl::device device; }; template <> -struct generic_device_context_type { +struct generic_context_type { using handle_t = cl_context; using impl_t = cl::sycl::context; using ccl_native_t = impl_t; - generic_device_context_type(); - generic_device_context_type(ccl_native_t ctx); + generic_context_type(); + generic_context_type(ccl_native_t ctx); ccl_native_t& get() noexcept; const ccl_native_t& get() const noexcept; @@ -102,4 +104,4 @@ struct generic_event_type { API_CLASS_TYPE_INFO(cl_command_queue); API_CLASS_TYPE_INFO(cl_context); API_CLASS_TYPE_INFO(cl_event) -} +} // namespace ccl diff --git a/include/oneapi/ccl/native_device_api/sycl_l0/export.hpp b/include/oneapi/ccl/native_device_api/sycl_l0/export.hpp index cc8d49ac6..07d06eaba 100644 --- a/include/oneapi/ccl/native_device_api/sycl_l0/export.hpp +++ b/include/oneapi/ccl/native_device_api/sycl_l0/export.hpp @@ -14,24 +14,26 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_types.hpp" + +#include "oneapi/ccl/types.hpp" #define CL_BACKEND_TYPE ccl::cl_backend_type::dpcpp_sycl_l0 #include -namespace ccl -{ +namespace ccl { template <> struct backend_info { CCL_API static constexpr ccl::cl_backend_type type() { - return CL_BACKEND_TYPE; } + return CL_BACKEND_TYPE; + } CCL_API static constexpr const char* name() { - return "CL_DPCPP_POWERED_L0_BACKEND"; } + return "CL_DPCPP_POWERED_L0_BACKEND"; + } }; template <> struct generic_device_type { - using handle_t = cl_device_id;//cl::sycl::device; + using handle_t = cl_device_id; //cl::sycl::device; using impl_t = cl::sycl::device; using ccl_native_t = impl_t; @@ -40,18 +42,19 @@ struct generic_device_type { generic_device_type(const cl::sycl::device& device); device_index_type get_id() const; ccl_native_t& get() noexcept; + const ccl_native_t& get() const noexcept; cl::sycl::device device; }; template <> -struct generic_device_context_type { +struct generic_context_type { using handle_t = cl_context; using impl_t = cl::sycl::context; using ccl_native_t = impl_t; - generic_device_context_type(); - generic_device_context_type(ccl_native_t ctx); + generic_context_type(); + generic_context_type(ccl_native_t ctx); ccl_native_t& get() noexcept; const ccl_native_t& get() const noexcept; @@ -102,4 +105,4 @@ struct generic_event_type { API_CLASS_TYPE_INFO(cl_command_queue); API_CLASS_TYPE_INFO(cl_context); API_CLASS_TYPE_INFO(cl_event) -} +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_stream.hpp b/include/oneapi/ccl/stream.hpp similarity index 63% rename from include/oneapi/ccl/ccl_stream.hpp rename to include/oneapi/ccl/stream.hpp index 2dc7f82f9..9fbb07089 100644 --- a/include/oneapi/ccl/ccl_stream.hpp +++ b/include/oneapi/ccl/stream.hpp @@ -21,10 +21,18 @@ class ccl_stream; namespace ccl { +namespace detail { +class environment; +} + +namespace v1 { +struct ccl_empty_attr; +class communicator; +struct impl_dispatch; /** * A stream object is an 
abstraction over CPU/GPU streams - * Has no defined public constructor. Use ccl::environment::create_stream + * Has no defined public constructor. Use ccl::create_stream * for stream objects creation */ /** @@ -47,8 +55,9 @@ class stream : public ccl_api_base_copyable::return_type; + using native_t = + typename detail::ccl_api_type_attr_traits::return_type; ~stream(); @@ -60,25 +69,25 @@ class stream : public ccl_api_base_copyable - const typename details::ccl_api_type_attr_traits::return_type& get() + const typename detail::ccl_api_type_attr_traits::return_type& get() const; /** * Get native stream object */ - native_t& get_native(); - const native_t& get_native() const; + native_t& get_native(); + const native_t& get_native() const; + private: - friend class environment; - friend class communicator; - friend struct ccl_empty_attr; - friend struct impl_dispatch; + friend class ccl::detail::environment; + friend class ccl::v1::communicator; + friend struct ccl::ccl_empty_attr; + friend struct ccl::v1::impl_dispatch; template - friend stream create_stream_from_attr( - typename unified_device_type::ccl_native_t device, - typename unified_device_context_type::ccl_native_t context, - attr_value_pair_t&&... avps); + friend stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, + typename unified_context_type::ccl_native_t context, + attr_value_pair_t&&... avps); template friend stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, attr_value_pair_t&&... 
avps); @@ -91,12 +100,11 @@ class stream : public ccl_api_base_copyable()>::type*/> - typename details::ccl_api_type_attr_traits::return_type set(const Value& v); + typename detail::ccl_api_type_attr_traits::return_type set(const Value& v); void build_from_params(); - stream( - const typename details::ccl_api_type_attr_traits::type& version); + stream(const typename detail::ccl_api_type_attr_traits::type& version); /** * Factory methods @@ -115,20 +123,25 @@ class stream : public ccl_api_base_copyable - static stream create_stream_from_attr( - typename unified_device_type::ccl_native_t device, - typename unified_device_context_type::ccl_native_t context, - attr_value_pair_t&&... avps); + static stream create_stream_from_attr(typename unified_device_type::ccl_native_t device, + typename unified_context_type::ccl_native_t context, + attr_value_pair_t&&... avps); }; -template -constexpr auto attr_val(value_type v) - -> details::attr_value_tripple { - return details::attr_value_tripple(v); -} - /** * Declare extern empty attributes */ extern stream default_stream; + +template +constexpr auto attr_val(value_type v) -> detail::attr_value_triple { + return detail::attr_value_triple(v); +} + +} // namespace v1 + +using v1::stream; +using v1::default_stream; +using v1::attr_val; + } // namespace ccl diff --git a/include/oneapi/ccl/ccl_stream_attr_ids.hpp b/include/oneapi/ccl/stream_attr_ids.hpp similarity index 93% rename from include/oneapi/ccl/ccl_stream_attr_ids.hpp rename to include/oneapi/ccl/stream_attr_ids.hpp index d86dfedab..17080c98b 100644 --- a/include/oneapi/ccl/ccl_stream_attr_ids.hpp +++ b/include/oneapi/ccl/stream_attr_ids.hpp @@ -21,6 +21,9 @@ class ccl_stream; namespace ccl { + +namespace v1 { + /** * Stream attribute ids */ @@ -39,4 +42,8 @@ enum class stream_attr_id : int { last_value }; +} // namespace v1 + +using v1::stream_attr_id; + } // namespace ccl diff --git a/include/oneapi/ccl/ccl_stream_attr_ids_traits.hpp 
b/include/oneapi/ccl/stream_attr_ids_traits.hpp similarity index 92% rename from include/oneapi/ccl/ccl_stream_attr_ids_traits.hpp rename to include/oneapi/ccl/stream_attr_ids_traits.hpp index c50f05075..9d44e8f01 100644 --- a/include/oneapi/ccl/ccl_stream_attr_ids_traits.hpp +++ b/include/oneapi/ccl/stream_attr_ids_traits.hpp @@ -20,7 +20,8 @@ #endif namespace ccl { -namespace details { + +namespace detail { /** * Traits for stream attributes specializations @@ -47,8 +48,8 @@ struct ccl_api_type_attr_traits { template <> struct ccl_api_type_attr_traits { - using type = typename unified_device_context_type::ccl_native_t; - using handle_t = typename unified_device_context_type::handle_t; + using type = typename unified_context_type::ccl_native_t; + using handle_t = typename unified_context_type::handle_t; using return_type = type; }; @@ -82,5 +83,6 @@ struct ccl_api_type_attr_traits { using return_type = type; }; -} // namespace details +} // namespace detail + } // namespace ccl diff --git a/include/oneapi/ccl/string.hpp b/include/oneapi/ccl/string.hpp new file mode 100644 index 000000000..072642012 --- /dev/null +++ b/include/oneapi/ccl/string.hpp @@ -0,0 +1,161 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#pragma once + +#include +#include +#include + +namespace ccl { + +namespace v1 { + +class string { +public: + ~string() { + delete[] storage; + storage = nullptr; + len = 0; + } + + string() { + storage = new char[1]; + *storage = '\0'; + len = 0; + } + + string(const char* str) { + len = strlen(str); + storage = new char[len + 1]; + memcpy(storage, str, len * sizeof(char)); + storage[len] = '\0'; + } + + string(const string& str) { + len = str.len; + storage = new char[len + 1]; + memcpy(storage, str.storage, len * sizeof(char)); + storage[len] = '\0'; + } + + string(string&& str) noexcept { + len = str.len; + storage = str.storage; + str.len = 0; + str.storage = nullptr; + } + + string(const std::string& str) { + len = str.length(); + storage = new char[len + 1]; + memcpy(storage, str.c_str(), len * sizeof(char)); + storage[len] = '\0'; + } + + string& operator=(const string& str) { + if (this != &str) { + if (len != str.len) { + len = str.len; + delete[] storage; + storage = new char[len + 1]; + } + memcpy(storage, str.storage, len * sizeof(char)); + storage[len] = '\0'; + } + return *this; + } + + string& operator=(string&& str) noexcept { + len = str.len; + storage = str.storage; + str.len = 0; + str.storage = nullptr; + return *this; + } + + size_t length() const { + return len; + } + + const char* c_str() const { + return storage; + }; + + operator std::string() const { + return std::string(storage); + } + + friend std::ostream& operator<<(std::ostream& out, const string& str) { + out << str.storage; + return out; + } + + string operator+(const char* str) { + auto str_len = strlen(str); + if (str_len > 0) { + auto new_storage = new char[len + str_len + 1]; + memcpy(new_storage, storage, len * sizeof(char)); + memcpy(&new_storage[len], str, str_len * sizeof(char)); + new_storage[len + str_len] = '\0'; + string res(new_storage); + delete[] new_storage; + return res; + } + return string(storage); + } + + string operator+(const string& str) { + return 
(*this + str.c_str()); + } + + string operator+(const std::string& str) { + return (*this + str.c_str()); + } + + friend std::string operator+(const std::string& str1, const string& str2) { + return (str1 + str2.c_str()); + } + + friend bool operator>(const string& str1, const string& str2) { + return strcmp(str1.c_str(), str2.c_str()) > 0; + } + + friend bool operator<=(const string& str1, const string& str2) { + return strcmp(str1.c_str(), str2.c_str()) <= 0; + } + + friend bool operator<(const string& str1, const string& str2) { + return strcmp(str1.c_str(), str2.c_str()) < 0; + } + + friend bool operator>=(const string& str1, const string& str2) { + return strcmp(str1.c_str(), str2.c_str()) >= 0; + } + + friend bool operator==(const string& str1, const string& str2) { + return strcmp(str1.c_str(), str2.c_str()) == 0; + } + +private: + size_t len; + char* storage; +}; + +} // namespace v1 + +using v1::string; + +} // namespace ccl diff --git a/include/oneapi/ccl/type_traits.hpp b/include/oneapi/ccl/type_traits.hpp new file mode 100644 index 000000000..43db30e95 --- /dev/null +++ b/include/oneapi/ccl/type_traits.hpp @@ -0,0 +1,168 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#pragma once + +#include +#include + +#ifdef CCL_ENABLE_SYCL +#include +#endif + +#include "oneapi/ccl/lp_types.hpp" +#include "oneapi/ccl/types.hpp" + +namespace ccl { +/** + * Base type-trait helpers for "unknown" types + */ +template +struct type_info { + static constexpr bool is_supported = false; + static constexpr bool is_class = false; +}; + +template +struct native_type_info { + static constexpr bool is_supported = false; + static constexpr bool is_class = false; +}; + +#define CCL_TYPE_TRAITS(ccl_type, cpp_type, bytes, str) \ + template <> \ + struct type_info \ + : public ccl_type_info_export { \ + static constexpr const char* name() { \ + return #str; \ + } \ + }; \ + template <> \ + struct native_type_info : public type_info {}; + +#define CCL_CLASS_TYPE_TRAITS(ccl_type, cpp_type, bytes, str) \ + template <> \ + struct native_type_info \ + : public ccl_type_info_export { \ + static constexpr const char* name() { \ + return #str; \ + } \ + }; + +#define COMMA , + +/** + * Enumeration of supported CCL API data types + */ + +CCL_TYPE_TRAITS(ccl::datatype::int8, int8_t, sizeof(int8_t), int8) +CCL_TYPE_TRAITS(ccl::datatype::uint8, uint8_t, sizeof(uint8_t), uint8) +CCL_TYPE_TRAITS(ccl::datatype::int16, int16_t, sizeof(int16_t), int16) +CCL_TYPE_TRAITS(ccl::datatype::uint16, uint16_t, sizeof(uint16_t), uint16) +CCL_TYPE_TRAITS(ccl::datatype::int32, int32_t, sizeof(int32_t), int32) +CCL_TYPE_TRAITS(ccl::datatype::uint32, uint32_t, sizeof(uint32_t), uint32) +CCL_TYPE_TRAITS(ccl::datatype::int64, int64_t, sizeof(int64_t), int64) +CCL_TYPE_TRAITS(ccl::datatype::uint64, uint64_t, sizeof(uint64_t), uint64) +//CCL_TYPE_TRAITS(ccl::datatype::float16, float16, sizeof(float16), float16) +CCL_TYPE_TRAITS(ccl::datatype::float32, float, sizeof(float), float32) +CCL_TYPE_TRAITS(ccl::datatype::float64, double, sizeof(double), float64) +CCL_TYPE_TRAITS(ccl::datatype::bfloat16, bfloat16, sizeof(bfloat16), bfloat16) + +#ifdef CCL_ENABLE_SYCL 
+CCL_CLASS_TYPE_TRAITS(ccl::datatype::int8, cl::sycl::buffer, sizeof(int8_t), int8) +CCL_CLASS_TYPE_TRAITS(ccl::datatype::uint8, + cl::sycl::buffer, + sizeof(uint8_t), + uint8) +CCL_CLASS_TYPE_TRAITS(ccl::datatype::int16, + cl::sycl::buffer, + sizeof(int16_t), + int16) +CCL_CLASS_TYPE_TRAITS(ccl::datatype::uint16, + cl::sycl::buffer, + sizeof(uint16_t), + uint16) +CCL_CLASS_TYPE_TRAITS(ccl::datatype::int32, + cl::sycl::buffer, + sizeof(int32_t), + int32) +CCL_CLASS_TYPE_TRAITS(ccl::datatype::uint32, + cl::sycl::buffer, + sizeof(uint32_t), + uint32) +CCL_CLASS_TYPE_TRAITS(ccl::datatype::int64, + cl::sycl::buffer, + sizeof(int64_t), + int64) +CCL_CLASS_TYPE_TRAITS(ccl::datatype::uint64, + cl::sycl::buffer, + sizeof(uint64_t), + uint64) +// CCL_CLASS_TYPE_TRAITS(ccl::datatype::float16, +// cl::sycl::buffer, +// sizeof(float16), +// float16) +CCL_CLASS_TYPE_TRAITS(ccl::datatype::float32, + cl::sycl::buffer, + sizeof(float), + float32) +CCL_CLASS_TYPE_TRAITS(ccl::datatype::float64, + cl::sycl::buffer, + sizeof(double), + float64) +CCL_CLASS_TYPE_TRAITS(ccl::datatype::bfloat16, + cl::sycl::buffer, + sizeof(bfloat16), + bfloat16) +#endif /* CCL_ENABLE_SYCL */ + +/** + * Checks for supporting @c type in ccl API + */ +template +constexpr bool is_supported() { + using clear_type = typename std::remove_pointer::type; + // static_assert(native_type_info::is_supported, "type is not supported by ccl API"); + return native_type_info::is_supported; +} + +/** + * Checks is @c type a class + */ +template +constexpr bool is_class() { + using clear_type = typename std::remove_pointer::type; + return native_type_info::is_class; +} + +/** + * SFINAE checks for supporting native type @c type in ccl API + */ +template +constexpr bool is_native_type_supported() { + return (not is_class() and is_supported()); +} + +/** + * SFINAE checks for supporting class @c type in ccl API + */ +template +constexpr bool is_class_supported() { + return (is_class() and is_supported()); +} + +} // namespace 
ccl + +#include "oneapi/ccl/device_type_traits.hpp" diff --git a/include/oneapi/ccl/types.hpp b/include/oneapi/ccl/types.hpp new file mode 100644 index 000000000..ce7b2bb14 --- /dev/null +++ b/include/oneapi/ccl/types.hpp @@ -0,0 +1,195 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include +#include +#include "oneapi/ccl/config.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/exception.hpp" + +namespace ccl { + +namespace v1 { + +/** + * Supported reduction operations + */ +enum class reduction : int { + sum = 0, + prod, + min, + max, + custom, + + last_value +}; + +/** + * Supported datatypes + */ +enum class datatype : int { + int8 = 0, + uint8, + int16, + uint16, + int32, + uint32, + int64, + uint64, + + float16, + float32, + float64, + + bfloat16, + + last_predefined = bfloat16 +}; + +/** + * Supported CL backend types + */ +enum class cl_backend_type : int { + empty_backend = 0x0, + dpcpp_sycl = 0x1, + l0 = 0x2, + dpcpp_sycl_l0 = 0x3, + + last_value +}; + +} // namespace v1 + +using v1::reduction; +using v1::datatype; +using v1::cl_backend_type; + +/** + * Type traits, which describes how-to types would be interpretered by ccl API + */ +template +struct ccl_type_info_export { + using native_type = ntype_t; + using ccl_type = std::integral_constant; + static constexpr size_t size = size_of_type; + static constexpr 
datatype dtype = static_cast(ccl_type::value); + static constexpr bool is_class = iclass; + static constexpr bool is_supported = supported; +}; + +namespace v1 { + +/** + * Library version description + */ +typedef struct { + unsigned int major; + unsigned int minor; + unsigned int update; + const char* product_status; + const char* build_date; + const char* full; + string_class cl_backend_name; +} library_version; + +typedef struct { + const char* match_id; + const size_t offset; +} fn_context; + +/* in_buf, in_count, inout_buf, out_count, dtype, context */ +typedef void ( + *reduction_fn)(const void*, size_t, void*, size_t*, ccl::datatype, const ccl::v1::fn_context*); + +struct ccl_empty_attr { + static ccl::v1::library_version version; + + template + static attr create_empty(); +}; + +/** + * Sparse coalesce modes + * + * Use this variable to set sparse_allreduce coalescing mode: + * regular - run regular coalesce funtion; + * disable - disables coalesce function in sparse_allreduce, + * allgathered data is returned; + * keep_precision - on every local reduce bf16 data is converted to fp32, + * reduced and then converted back to bf16. 
+ */ +enum class sparse_coalesce_mode : int { + regular = 0, + disable, + keep_precision, + + last_value +}; + +/* idx_buf, idx_count, idx_dtype, val_buf, val_count, val_dtype, user_context */ +typedef void (*sparse_allreduce_completion_fn)(const void*, + size_t, + ccl::datatype, + const void*, + size_t, + ccl::datatype, + const void*); + +/* idx_count, idx_dtype, val_count, val_dtype, user_context, out_idx_buf, out_val_buf */ +typedef void (*sparse_allreduce_alloc_fn)(size_t, + ccl::datatype, + size_t, + ccl::datatype, + const void*, + void**, + void**); +} // namespace v1 + +using v1::library_version; +using v1::fn_context; +using v1::reduction_fn; +using v1::ccl_empty_attr; + +using v1::sparse_coalesce_mode; +using v1::sparse_allreduce_completion_fn; +using v1::sparse_allreduce_alloc_fn; + +/** + * API object attributes traits + */ +namespace info { +template +struct param_traits {}; + +} //namespace info +} // namespace ccl + +#include "oneapi/ccl/device_types.hpp" diff --git a/include/oneapi/ccl/ccl_types_policy.hpp b/include/oneapi/ccl/types_policy.hpp similarity index 97% rename from include/oneapi/ccl/ccl_types_policy.hpp rename to include/oneapi/ccl/types_policy.hpp index aebedd25a..4ebd0bffc 100644 --- a/include/oneapi/ccl/ccl_types_policy.hpp +++ b/include/oneapi/ccl/types_policy.hpp @@ -188,19 +188,20 @@ class ccl_api_base_movable : protected access_policy_t { private: impl_value_t pimpl; }; -namespace details { + +namespace detail { template struct ccl_api_type_attr_traits {}; template -struct attr_value_tripple { +struct attr_value_triple { using type_t = attrib_id; using value_t = value_type; static constexpr attrib_id idx() { return attrId; } - explicit attr_value_tripple(value_t val) : m_val(val) {} + explicit attr_value_triple(value_t val) : m_val(val) {} const value_type& val() { return m_val; } @@ -208,6 +209,6 @@ struct attr_value_tripple { private: value_t m_val; }; +} // namespace detail -} // namespace details } // namespace ccl diff --git 
a/mpi/bin/hydra_bstrap_proxy b/mpi/bin/hydra_bstrap_proxy index 6ce0a503d..03d31da92 100755 Binary files a/mpi/bin/hydra_bstrap_proxy and b/mpi/bin/hydra_bstrap_proxy differ diff --git a/mpi/bin/hydra_nameserver b/mpi/bin/hydra_nameserver index b6471ef4a..dc1b3e064 100755 Binary files a/mpi/bin/hydra_nameserver and b/mpi/bin/hydra_nameserver differ diff --git a/mpi/bin/hydra_pmi_proxy b/mpi/bin/hydra_pmi_proxy index 67b22b917..73978d8d1 100755 Binary files a/mpi/bin/hydra_pmi_proxy and b/mpi/bin/hydra_pmi_proxy differ diff --git a/mpi/bin/mpiexec b/mpi/bin/mpiexec deleted file mode 120000 index 482a69296..000000000 --- a/mpi/bin/mpiexec +++ /dev/null @@ -1 +0,0 @@ -mpiexec.hydra \ No newline at end of file diff --git a/mpi/bin/mpiexec b/mpi/bin/mpiexec new file mode 100755 index 000000000..423fc1a36 Binary files /dev/null and b/mpi/bin/mpiexec differ diff --git a/mpi/bin/mpiexec.hydra b/mpi/bin/mpiexec.hydra index 3e67f050e..423fc1a36 100755 Binary files a/mpi/bin/mpiexec.hydra and b/mpi/bin/mpiexec.hydra differ diff --git a/mpi/bin/mpigcc b/mpi/bin/mpigcc index 1be1d3ef5..ad2cc08a6 100755 --- a/mpi/bin/mpigcc +++ b/mpi/bin/mpigcc @@ -104,7 +104,7 @@ CFLAGS="" CPPFLAGS="" LDFLAGS=" -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl " LIBS="-lm -lpthread -lfabric -lrt " -MPIVERSION="2021.1-beta10" +MPIVERSION="2021.1" MPILIBNAME="mpi" diff --git a/mpi/bin/mpigxx b/mpi/bin/mpigxx index 2d3a8ad14..5524e9a49 100755 --- a/mpi/bin/mpigxx +++ b/mpi/bin/mpigxx @@ -101,7 +101,7 @@ MPICH_VERSION="3.3" CXXFLAGS="" LDFLAGS=" -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl " LIBS="-lm -lpthread -lfabric -lrt " -MPIVERSION="2021.1-beta10" +MPIVERSION="2021.1" MPILIBNAME="mpi" MPICXXLIBNAME="mpicxx" diff --git a/mpi/bin/mpiicc b/mpi/bin/mpiicc index b938fe315..aff4c8024 100755 --- a/mpi/bin/mpiicc +++ b/mpi/bin/mpiicc @@ -122,7 +122,7 @@ MPILIBNAME="mpi" PMPILIBNAME="pmpi" # MPIVERSION is the version of the MPICH2 library that 
mpicc is intended for -MPIVERSION="2021.1-beta10" +MPIVERSION="2021.1" # # Internal variables # Show is set to echo to cause the compilation command to be echoed instead diff --git a/mpi/bin/mpiicpc b/mpi/bin/mpiicpc index 62ce4df55..1e172575b 100755 --- a/mpi/bin/mpiicpc +++ b/mpi/bin/mpiicpc @@ -121,7 +121,7 @@ PMPILIBNAME="pmpi" MPICXXLIBNAME="mpicxx" # MPIVERSION is the version of the Intel(R) MPI Library that mpiicpc is intended for -MPIVERSION="2021.1-beta10" +MPIVERSION="2021.1" # # Internal variables # Show is set to echo to cause the compilation command to be echoed instead diff --git a/mpi/etc/tuning_clx-ap_shm-ofi.dat b/mpi/etc/tuning_clx-ap_shm-ofi.dat index 841f2d778..ce04f0d57 100755 Binary files a/mpi/etc/tuning_clx-ap_shm-ofi.dat and b/mpi/etc/tuning_clx-ap_shm-ofi.dat differ diff --git a/mpi/include/mpi.h b/mpi/include/mpi.h index 301e0e9eb..de4cba113 100755 --- a/mpi/include/mpi.h +++ b/mpi/include/mpi.h @@ -595,8 +595,8 @@ typedef int (MPI_Delete_function) ( MPI_Comm, int, void *, void * ); * digits for REV, 1 digit for EXT and 2 digits for EXT_NUMBER. So, * 2019.0.0b0 will have the numeric version 20190000100. 
*/ -#define I_MPI_VERSION "2021.1.0b10" -#define I_MPI_NUMVERSION 20210100110 +#define I_MPI_VERSION "2021.1.0" +#define I_MPI_NUMVERSION 20210100300 /* for the datatype decoders */ enum MPIR_Combiner_enum { diff --git a/mpi/lib/libmpi.so b/mpi/lib/libmpi.so deleted file mode 120000 index 9e4b9f431..000000000 --- a/mpi/lib/libmpi.so +++ /dev/null @@ -1 +0,0 @@ -libmpi.so.12.0 \ No newline at end of file diff --git a/mpi/lib/libmpi.so b/mpi/lib/libmpi.so new file mode 100755 index 000000000..caeb9d1ac Binary files /dev/null and b/mpi/lib/libmpi.so differ diff --git a/mpi/lib/libmpi.so.12 b/mpi/lib/libmpi.so.12 deleted file mode 120000 index 5a0e391d4..000000000 --- a/mpi/lib/libmpi.so.12 +++ /dev/null @@ -1 +0,0 @@ -libmpi.so.12.0.0 \ No newline at end of file diff --git a/mpi/lib/libmpi.so.12 b/mpi/lib/libmpi.so.12 new file mode 100755 index 000000000..caeb9d1ac Binary files /dev/null and b/mpi/lib/libmpi.so.12 differ diff --git a/mpi/lib/libmpi.so.12.0 b/mpi/lib/libmpi.so.12.0 deleted file mode 120000 index 5a0e391d4..000000000 --- a/mpi/lib/libmpi.so.12.0 +++ /dev/null @@ -1 +0,0 @@ -libmpi.so.12.0.0 \ No newline at end of file diff --git a/mpi/lib/libmpi.so.12.0 b/mpi/lib/libmpi.so.12.0 new file mode 100755 index 000000000..caeb9d1ac Binary files /dev/null and b/mpi/lib/libmpi.so.12.0 differ diff --git a/mpi/lib/libmpi.so.12.0.0 b/mpi/lib/libmpi.so.12.0.0 index 54cbc716e..caeb9d1ac 100755 Binary files a/mpi/lib/libmpi.so.12.0.0 and b/mpi/lib/libmpi.so.12.0.0 differ diff --git a/mpi/lib/libmpicxx.so b/mpi/lib/libmpicxx.so deleted file mode 120000 index 9e27e2a69..000000000 --- a/mpi/lib/libmpicxx.so +++ /dev/null @@ -1 +0,0 @@ -libmpicxx.so.12.0.0 \ No newline at end of file diff --git a/mpi/lib/libmpicxx.so b/mpi/lib/libmpicxx.so new file mode 100755 index 000000000..ee69659ef Binary files /dev/null and b/mpi/lib/libmpicxx.so differ diff --git a/mpi/lib/libmpicxx.so.12 b/mpi/lib/libmpicxx.so.12 deleted file mode 120000 index 9e27e2a69..000000000 --- 
a/mpi/lib/libmpicxx.so.12 +++ /dev/null @@ -1 +0,0 @@ -libmpicxx.so.12.0.0 \ No newline at end of file diff --git a/mpi/lib/libmpicxx.so.12 b/mpi/lib/libmpicxx.so.12 new file mode 100755 index 000000000..ee69659ef Binary files /dev/null and b/mpi/lib/libmpicxx.so.12 differ diff --git a/mpi/lib/libmpicxx.so.12.0 b/mpi/lib/libmpicxx.so.12.0 deleted file mode 120000 index 9e27e2a69..000000000 --- a/mpi/lib/libmpicxx.so.12.0 +++ /dev/null @@ -1 +0,0 @@ -libmpicxx.so.12.0.0 \ No newline at end of file diff --git a/mpi/lib/libmpicxx.so.12.0 b/mpi/lib/libmpicxx.so.12.0 new file mode 100755 index 000000000..ee69659ef Binary files /dev/null and b/mpi/lib/libmpicxx.so.12.0 differ diff --git a/mpi/lib/libmpifort.so b/mpi/lib/libmpifort.so deleted file mode 120000 index 3dc64470d..000000000 --- a/mpi/lib/libmpifort.so +++ /dev/null @@ -1 +0,0 @@ -libmpifort.so.12.0.0 \ No newline at end of file diff --git a/mpi/lib/libmpifort.so b/mpi/lib/libmpifort.so new file mode 100755 index 000000000..7e12bff9b Binary files /dev/null and b/mpi/lib/libmpifort.so differ diff --git a/mpi/lib/libmpifort.so.12 b/mpi/lib/libmpifort.so.12 deleted file mode 120000 index 3dc64470d..000000000 --- a/mpi/lib/libmpifort.so.12 +++ /dev/null @@ -1 +0,0 @@ -libmpifort.so.12.0.0 \ No newline at end of file diff --git a/mpi/lib/libmpifort.so.12 b/mpi/lib/libmpifort.so.12 new file mode 100755 index 000000000..7e12bff9b Binary files /dev/null and b/mpi/lib/libmpifort.so.12 differ diff --git a/mpi/lib/libmpifort.so.12.0 b/mpi/lib/libmpifort.so.12.0 deleted file mode 120000 index 3dc64470d..000000000 --- a/mpi/lib/libmpifort.so.12.0 +++ /dev/null @@ -1 +0,0 @@ -libmpifort.so.12.0.0 \ No newline at end of file diff --git a/mpi/lib/libmpifort.so.12.0 b/mpi/lib/libmpifort.so.12.0 new file mode 100755 index 000000000..7e12bff9b Binary files /dev/null and b/mpi/lib/libmpifort.so.12.0 differ diff --git a/mpi/lib/libmpifort.so.12.0.0 b/mpi/lib/libmpifort.so.12.0.0 index e8bf6822d..7e12bff9b 100755 Binary files 
a/mpi/lib/libmpifort.so.12.0.0 and b/mpi/lib/libmpifort.so.12.0.0 differ diff --git a/mpi/licensing/license.txt b/mpi/licensing/license.txt old mode 100644 new mode 100755 index 6ca7736f1..ffffdc860 --- a/mpi/licensing/license.txt +++ b/mpi/licensing/license.txt @@ -1,77 +1,77 @@ -Intel Simplified Software License (Version April 2018) +Intel Simplified Software License (Version February 2020) -Copyright (c) 2018 Intel Corporation. - -Use and Redistribution. You may use and redistribute the software (the +Use and Redistribution. You may use and redistribute the software (the "Software"), without modification, provided the following conditions are met: -* Redistributions must reproduce the above copyright notice and the following - terms of use in the Software and in the documentation and/or other materials +* Redistributions must reproduce the above copyright notice and the following + terms of use in the Software and in the documentation and/or other materials provided with the distribution. -* Neither the name of Intel nor the names of its suppliers may be used to - endorse or promote products derived from this Software without specific prior +* Neither the name of Intel nor the names of its suppliers may be used to + endorse or promote products derived from this Software without specific prior written permission. -* No reverse engineering, decompilation, or disassembly of this Software is +* No reverse engineering, decompilation, or disassembly of this Software is permitted. -Limited patent license. Intel grants you a world-wide, royalty-free, -non-exclusive license under patents it now or hereafter owns or controls to -make, have made, use, import, offer to sell and sell ("Utilize") this Software, -but solely to the extent that any such patent is necessary to Utilize the -Software alone. The patent license shall not apply to any combinations which +Limited patent license. 
Intel grants you a world-wide, royalty-free, +non-exclusive license under patents it now or hereafter owns or controls to +make, have made, use, import, offer to sell and sell ("Utilize") this Software, +but solely to the extent that any such patent is necessary to Utilize the +Software alone. The patent license shall not apply to any combinations which include this software. No hardware per se is licensed hereunder. -Third party and other Intel programs. "Third Party Programs" are the files -listed in the "third-party-programs.txt" text file that is included with the -Software and may include Intel programs under separate license terms. Third -Party Programs, even if included with the distribution of the Materials, are -governed by separate license terms and those license terms solely govern your -use of those programs. +Third party programs. The Software may contain Third Party Programs. "Third +Party Programs" are third party software, open source software or other Intel +software listed in the "third-party-programs.txt" or other similarly named text +file that is included with the Software. Third Party Programs, even if included +with the distribution of the Software, may be governed by separate license +terms, including without limitation, third party license terms, open source +software notices and terms, and/or other Intel software license terms. These +separate license terms may govern your use of the Third Party Programs. -DISCLAIMER. THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED -WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE -DISCLAIMED. 
THIS SOFTWARE IS NOT INTENDED FOR USE IN SYSTEMS OR APPLICATIONS -WHERE FAILURE OF THE SOFTWARE MAY CAUSE PERSONAL INJURY OR DEATH AND YOU AGREE -THAT YOU ARE FULLY RESPONSIBLE FOR ANY CLAIMS, COSTS, DAMAGES, EXPENSES, AND -ATTORNEYS' FEES ARISING OUT OF ANY SUCH USE, EVEN IF ANY CLAIM ALLEGES THAT +DISCLAIMER. THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED +WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE +DISCLAIMED. THIS SOFTWARE IS NOT INTENDED FOR USE IN SYSTEMS OR APPLICATIONS +WHERE FAILURE OF THE SOFTWARE MAY CAUSE PERSONAL INJURY OR DEATH AND YOU AGREE +THAT YOU ARE FULLY RESPONSIBLE FOR ANY CLAIMS, COSTS, DAMAGES, EXPENSES, AND +ATTORNEYS' FEES ARISING OUT OF ANY SUCH USE, EVEN IF ANY CLAIM ALLEGES THAT INTEL WAS NEGLIGENT REGARDING THE DESIGN OR MANUFACTURE OF THE MATERIALS. -LIMITATION OF LIABILITY. IN NO EVENT WILL INTEL BE LIABLE FOR ANY DIRECT, -INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE -OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF -ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. YOU AGREE TO INDEMNIFY AND HOLD INTEL -HARMLESS AGAINST ANY CLAIMS AND EXPENSES RESULTING FROM YOUR USE OR UNAUTHORIZED -USE OF THE SOFTWARE. +LIMITATION OF LIABILITY. 
IN NO EVENT WILL INTEL BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. YOU AGREE TO INDEMNIFY AND HOLD +INTEL HARMLESS AGAINST ANY CLAIMS AND EXPENSES RESULTING FROM YOUR USE OR +UNAUTHORIZED USE OF THE SOFTWARE. -No support. Intel may make changes to the Software, at any time without notice, -and is not obligated to support, update or provide training for the Software. +No support. Intel may make changes to the Software, at any time without notice, +and is not obligated to support, update or provide training for the Software. -Termination. Intel may terminate your right to use the Software in the event of -your breach of this Agreement and you fail to cure the breach within a +Termination. Intel may terminate your right to use the Software in the event of +your breach of this Agreement and you fail to cure the breach within a reasonable period of time. -Feedback. Should you provide Intel with comments, modifications, corrections, -enhancements or other input ("Feedback") related to the Software Intel will be -free to use, disclose, reproduce, license or otherwise distribute or exploit the -Feedback in its sole discretion without any obligations or restrictions of any -kind, including without limitation, intellectual property rights or licensing +Feedback. 
Should you provide Intel with comments, modifications, corrections, +enhancements or other input ("Feedback") related to the Software Intel will be +free to use, disclose, reproduce, license or otherwise distribute or exploit the +Feedback in its sole discretion without any obligations or restrictions of any +kind, including without limitation, intellectual property rights or licensing obligations. -Compliance with laws. You agree to comply with all relevant laws and -regulations governing your use, transfer, import or export (or prohibition -thereof) of the Software. +Compliance with laws. You agree to comply with all relevant laws and regulations +governing your use, transfer, import or export (or prohibition thereof) of the +Software. -Governing law. All disputes will be governed by the laws of the United States -of America and the State of Delaware without reference to conflict of law -principles and subject to the exclusive jurisdiction of the state or federal -courts sitting in the State of Delaware, and each party agrees that it submits -to the personal jurisdiction and venue of those courts and waives any -objections. The United Nations Convention on Contracts for the International -Sale of Goods (1980) is specifically excluded and will not apply to the +Governing law. All disputes will be governed by the laws of the United States of +America and the State of Delaware without reference to conflict of law +principles and subject to the exclusive jurisdiction of the state or federal +courts sitting in the State of Delaware, and each party agrees that it submits +to the personal jurisdiction and venue of those courts and waives any +objections. The United Nations Convention on Contracts for the International +Sale of Goods (1980) is specifically excluded and will not apply to the Software. *Other names and brands may be claimed as the property of others. 
diff --git a/mpi/licensing/third-party-programs.txt b/mpi/licensing/third-party-programs.txt index 34278eb40..0dbbc92ff 100755 --- a/mpi/licensing/third-party-programs.txt +++ b/mpi/licensing/third-party-programs.txt @@ -1,4 +1,4 @@ -Intel(R) MPI Library 2021.1-beta10 Third Party Programs File +Intel(R) MPI Library 2021.1 Third Party Programs File This file is the "third-party-programs.txt" file specified in the associated Intel end user license agreement for the Intel software you are licensing. diff --git a/ofi/bin/fi_info b/ofi/bin/fi_info index d50d3dfba..463945e25 100755 Binary files a/ofi/bin/fi_info and b/ofi/bin/fi_info differ diff --git a/ofi/include/rdma/fabric.h b/ofi/include/rdma/fabric.h index 9e53861e2..ce0b91805 100644 --- a/ofi/include/rdma/fabric.h +++ b/ofi/include/rdma/fabric.h @@ -16,6 +16,7 @@ /* * Copyright (c) 2013-2017 Intel Corporation. All rights reserved. * Copyright (c) 2016 Cisco Systems, Inc. All rights reserved. + * (C) Copyright 2020 Hewlett Packard Enterprise Development LP * * This software is available to you under a choice of one of two * licenses. 
You may choose to be licensed under the terms of the GNU @@ -53,6 +54,7 @@ #include #include #include +#include #ifdef __GNUC__ #define FI_DEPRECATED_FUNC __attribute__((deprecated)) @@ -92,8 +94,8 @@ extern "C" { #endif #define FI_MAJOR_VERSION 1 -#define FI_MINOR_VERSION 10 -#define FI_REVISION_VERSION 1 +#define FI_MINOR_VERSION 11 +#define FI_REVISION_VERSION 0 enum { FI_PATH_MAX = 256, @@ -167,6 +169,7 @@ typedef struct fid *fid_t; #define FI_PEEK (1ULL << 19) #define FI_TRIGGER (1ULL << 20) #define FI_FENCE (1ULL << 21) +#define FI_PRIORITY (1ULL << 22) #define FI_COMPLETION (1ULL << 24) #define FI_EVENT FI_COMPLETION @@ -545,6 +548,8 @@ struct fi_ops { int (*ops_open)(struct fid *fid, const char *name, uint64_t flags, void **ops, void *context); int (*tostr)(const struct fid *fid, char *buf, size_t len); + int (*ops_set)(struct fid *fid, const char *name, uint64_t flags, + void *ops, void *context); }; /* All fabric interface descriptors must start with this structure */ @@ -664,6 +669,14 @@ fi_open_ops(struct fid *fid, const char *name, uint64_t flags, return fid->ops->ops_open(fid, name, flags, ops, context); } +static inline int +fi_set_ops(struct fid *fid, const char *name, uint64_t flags, + void *ops, void *context) +{ + return FI_CHECK_OP(fid->ops, struct fi_ops, ops_set) ? + fid->ops->ops_set(fid, name, flags, ops, context) : -FI_ENOSYS; +} + enum fi_type { FI_TYPE_INFO, FI_TYPE_EP_TYPE, @@ -690,6 +703,7 @@ enum fi_type { FI_TYPE_OP_TYPE, FI_TYPE_FID, FI_TYPE_COLLECTIVE_OP, + FI_TYPE_HMEM_IFACE, }; char *fi_tostr(const void *data, enum fi_type datatype); diff --git a/ofi/include/rdma/fi_domain.h b/ofi/include/rdma/fi_domain.h index 3682b019a..99cde56cc 100644 --- a/ofi/include/rdma/fi_domain.h +++ b/ofi/include/rdma/fi_domain.h @@ -15,6 +15,7 @@ */ /* * Copyright (c) 2013-2017 Intel Corporation. All rights reserved. 
+ * (C) Copyright 2020 Hewlett Packard Enterprise Development LP * * This software is available to you under a choice of one of two * licenses. You may choose to be licensed under the terms of the GNU @@ -130,6 +131,8 @@ struct fid_mr { enum fi_hmem_iface { FI_HMEM_SYSTEM = 0, FI_HMEM_CUDA, + FI_HMEM_ROCR, + FI_HMEM_ZE, }; struct fi_mr_attr { @@ -145,6 +148,7 @@ struct fi_mr_attr { union { uint64_t reserved; int cuda; + int ze; } device; }; @@ -153,6 +157,23 @@ struct fi_mr_modify { struct fi_mr_attr attr; }; +#define FI_SET_OPS_HMEM_OVERRIDE "hmem_override_ops" + +struct fi_hmem_override_ops { + size_t size; + + ssize_t (*copy_from_hmem_iov)(void *dest, size_t size, + enum fi_hmem_iface iface, uint64_t device, + const struct iovec *hmem_iov, + size_t hmem_iov_count, + uint64_t hmem_iov_offset); + + ssize_t (*copy_to_hmem_iov)(enum fi_hmem_iface iface, uint64_t device, + const struct iovec *hmem_iov, + size_t hmem_iov_count, + uint64_t hmem_iov_offset, const void *src, + size_t size); +}; #ifdef FABRIC_DIRECT #include @@ -258,8 +279,9 @@ struct fi_ops_domain { int (*query_atomic)(struct fid_domain *domain, enum fi_datatype datatype, enum fi_op op, struct fi_atomic_attr *attr, uint64_t flags); - int (*query_collective)(struct fid_domain *domain, enum fi_collective_op coll, - struct fi_collective_attr *attr, uint64_t flags); + int (*query_collective)(struct fid_domain *domain, + enum fi_collective_op coll, + struct fi_collective_attr *attr, uint64_t flags); }; /* Memory registration flags */ diff --git a/ofi/lib/libfabric.so b/ofi/lib/libfabric.so deleted file mode 120000 index 878a6164e..000000000 --- a/ofi/lib/libfabric.so +++ /dev/null @@ -1 +0,0 @@ -libfabric.so.1 \ No newline at end of file diff --git a/ofi/lib/libfabric.so b/ofi/lib/libfabric.so new file mode 100755 index 000000000..5a7dcef2a Binary files /dev/null and b/ofi/lib/libfabric.so differ diff --git a/ofi/lib/libfabric.so.1 b/ofi/lib/libfabric.so.1 index e4771f8ba..5a7dcef2a 100755 Binary files 
a/ofi/lib/libfabric.so.1 and b/ofi/lib/libfabric.so.1 differ diff --git a/ofi/lib/prov/libpsmx2-fi.so b/ofi/lib/prov/libpsmx2-fi.so index 33f555375..a925b003b 100755 Binary files a/ofi/lib/prov/libpsmx2-fi.so and b/ofi/lib/prov/libpsmx2-fi.so differ diff --git a/ofi/lib/prov/librxm-fi.so b/ofi/lib/prov/librxm-fi.so index 5f3a5eb2b..989b2e7f7 100755 Binary files a/ofi/lib/prov/librxm-fi.so and b/ofi/lib/prov/librxm-fi.so differ diff --git a/ofi/lib/prov/libshm-fi.so b/ofi/lib/prov/libshm-fi.so index 3cc3e8885..7128a4f80 100755 Binary files a/ofi/lib/prov/libshm-fi.so and b/ofi/lib/prov/libshm-fi.so differ diff --git a/ofi/lib/prov/libsockets-fi.so b/ofi/lib/prov/libsockets-fi.so index a0955552b..44ee76c31 100755 Binary files a/ofi/lib/prov/libsockets-fi.so and b/ofi/lib/prov/libsockets-fi.so differ diff --git a/ofi/lib/prov/libtcp-fi.so b/ofi/lib/prov/libtcp-fi.so index 741e90301..fbeaf72cb 100755 Binary files a/ofi/lib/prov/libtcp-fi.so and b/ofi/lib/prov/libtcp-fi.so differ diff --git a/ofi/lib/prov/libverbs-fi.so b/ofi/lib/prov/libverbs-fi.so index fb1dc7cd4..0803b1680 100755 Binary files a/ofi/lib/prov/libverbs-fi.so and b/ofi/lib/prov/libverbs-fi.so differ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 737ab770c..d5a25c0dd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -19,8 +19,6 @@ set (EXTENSIONS_SRC) if (CCL_ENABLE_SYCL) list (APPEND EXTENSIONS_SRC - ccl_cpp_gpu_api.cpp - native_device_api/l0/utils.cpp native_device_api/sycl/export.cpp native_device_api/interop_utils.cpp @@ -32,7 +30,6 @@ endif(CCL_ENABLE_SYCL) if (MULTI_GPU_SUPPORT) list (APPEND EXTENSIONS_SRC - ccl_cpp_gpu_api.cpp ccl_gpu_modules.cpp ccl_cpp_utils.cpp @@ -82,15 +79,16 @@ list (APPEND EXTENSIONS_SRC endif(MULTI_GPU_SUPPORT) set(CCL_SRC - ccl.cpp - ccl_cpp_api.cpp ccl_cpp_communicator.cpp ccl_cpp_environment.cpp ccl_api_functions.cpp ccl_app_api_coll_attr.cpp + ccl_app_api_comm_attr.cpp ccl_app_api_comm_split_attr.cpp ccl_app_api_datatype_attr.cpp + 
ccl_app_api_kvs_attr.cpp ccl_app_api_event.cpp + ccl_app_api_init_attr.cpp ccl_cpp_kvs.cpp ccl_cpp_device.cpp ccl_cpp_stream.cpp @@ -98,6 +96,10 @@ set(CCL_SRC ccl_cpp_utils.cpp ccl_empty_attr.cpp ccl_empty_coll_attr.cpp + ccl_empty_comm_attr.cpp + ccl_empty_init_attr.cpp + ccl_empty_comm_split_attr.cpp + ccl_empty_kvs_attr.cpp ccl_empty_stream.cpp native_device_api/sycl_l0/export.cpp native_device_api/empty/export.cpp @@ -108,12 +110,12 @@ set(CCL_SRC atl/util/pm/pmi_resizable_rt/pmi_resizable.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.cpp - atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.c + atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.cpp - atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.c - atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.c + atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.cpp + atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.cpp - atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.c + atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.cpp atl/util/pm/pmi_rt/pmi_simple.cpp @@ -165,6 +167,7 @@ set(CCL_SRC sched/entry/coll/coll_entry_helper.cpp sched/entry/entry.cpp sched/entry/factory/chunked_entry_factory.cpp + sched/entry/sycl_entry_helper.cpp exec/exec.cpp exec/thread/base_thread.cpp exec/thread/listener.cpp @@ -173,6 +176,7 @@ set(CCL_SRC fusion/fusion.cpp parallelizer/parallelizer.cpp unordered_coll/unordered_coll.cpp + common/comm/atl_tag.cpp common/comm/comm.cpp common/comm/comm_interface.cpp @@ -182,7 +186,6 @@ set(CCL_SRC common/datatype/datatype.cpp common/device/device.cpp common/event/ccl_event.cpp - common/event/event_internal/event_internal.cpp 
common/stream/stream.cpp common/env/env.cpp @@ -192,7 +195,9 @@ set(CCL_SRC common/event/impls/native_event.cpp common/request/request.cpp common/utils/spinlock.cpp + common/utils/version.cpp common/utils/yield.cpp + ${EXTENSIONS_SRC}) list(APPEND CCL_INC_DIRS @@ -248,7 +253,7 @@ install(FILES add_library(ccl-static STATIC $) set_target_properties(ccl-static PROPERTIES OUTPUT_NAME ccl) set_target_properties(ccl-static PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) -install(TARGETS ccl-static ARCHIVE DESTINATION ${CCL_INSTALL_LIB}) +install(TARGETS ccl-static ARCHIVE DESTINATION ${CCL_INSTALL_LIB} OPTIONAL) if(MULTI_GPU_SUPPORT) message("Turn on L0 multi-gpu unit tests") diff --git a/src/atl/CMakeLists.txt b/src/atl/CMakeLists.txt index bbf54f42c..e80b7b4c1 100644 --- a/src/atl/CMakeLists.txt +++ b/src/atl/CMakeLists.txt @@ -13,54 +13,54 @@ # See the License for the specific language governing permissions and # limitations under the License. # -#builds ccl_atl_ofi - -add_subdirectory(mpi) - -add_subdirectory(util/pm/pmi_rt/pmi) -add_subdirectory(util/pm/pmi_resizable_rt/pmi_resizable) - -set(OFI_SRC - ofi/atl_ofi.c - util/pm/pmi_rt/pmi_rt.c - util/pm/pmi_resizable_rt/pmi_resizable_rt.c) - -set(COMMON_OFI_INC_DIRS - ${CMAKE_CURRENT_SOURCE_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}/util/pm - ${CMAKE_CURRENT_SOURCE_DIR}/util/pm/codec - ${CMAKE_CURRENT_SOURCE_DIR}/util/pm/pmi_rt - ${CMAKE_CURRENT_SOURCE_DIR}/util/pm/pmi_resizable_rt - ${LIBFABRIC_INCLUDE_DIR}) - -#special library that holds objects only -add_library(ccl_atl_ofi-objects OBJECT ${OFI_SRC}) -set_target_properties(ccl_atl_ofi-objects PROPERTIES POSITION_INDEPENDENT_CODE 1) -target_include_directories(ccl_atl_ofi-objects PRIVATE ${COMMON_OFI_INC_DIRS}) -target_include_directories(ccl_atl_ofi-objects PRIVATE $) -target_include_directories(ccl_atl_ofi-objects PRIVATE $) - -#add library search directory - -#shared -add_library(ccl_atl_ofi SHARED $) -target_include_directories(ccl_atl_ofi PRIVATE 
${COMMON_OFI_INC_DIRS}) - -target_link_libraries(ccl_atl_ofi PRIVATE pmi) -target_link_libraries(ccl_atl_ofi PRIVATE resizable_pmi) -target_link_libraries(ccl_atl_ofi PRIVATE fabric m) - -if (NOT LIB_ATL_OFI_SO_VERSION AND NOT LIB_ATL_OFI_MAJOR_VERSION) - set_target_properties(ccl_atl_ofi PROPERTIES VERSION 1 SOVERSION 1.0) -else() - set_target_properties(ccl_atl_ofi PROPERTIES VERSION ${LIB_ATL_OFI_SO_VERSION} SOVERSION ${LIB_ATL_OFI_MAJOR_VERSION}) -endif() - -set_target_properties(ccl_atl_ofi PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) -install(TARGETS ccl_atl_ofi LIBRARY DESTINATION ${CCL_INSTALL_LIB}) - -#static -add_library(ccl_atl_ofi-static STATIC $) -set_target_properties(ccl_atl_ofi-static PROPERTIES OUTPUT_NAME ccl_atl_ofi) -set_target_properties(ccl_atl_ofi-static PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) -install(TARGETS ccl_atl_ofi-static ARCHIVE DESTINATION ${CCL_INSTALL_LIB}) +#builds ccl_atl_ofi + +add_subdirectory(mpi) + +add_subdirectory(util/pm/pmi_rt/pmi) +add_subdirectory(util/pm/pmi_resizable_rt/pmi_resizable) + +set(OFI_SRC + ofi/atl_ofi.c + util/pm/pmi_rt/pmi_rt.c + util/pm/pmi_resizable_rt/pmi_resizable_rt.c) + +set(COMMON_OFI_INC_DIRS + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/util/pm + ${CMAKE_CURRENT_SOURCE_DIR}/util/pm/codec + ${CMAKE_CURRENT_SOURCE_DIR}/util/pm/pmi_rt + ${CMAKE_CURRENT_SOURCE_DIR}/util/pm/pmi_resizable_rt + ${LIBFABRIC_INCLUDE_DIR}) + +#special library that holds objects only +add_library(ccl_atl_ofi-objects OBJECT ${OFI_SRC}) +set_target_properties(ccl_atl_ofi-objects PROPERTIES POSITION_INDEPENDENT_CODE 1) +target_include_directories(ccl_atl_ofi-objects PRIVATE ${COMMON_OFI_INC_DIRS}) +target_include_directories(ccl_atl_ofi-objects PRIVATE $) +target_include_directories(ccl_atl_ofi-objects PRIVATE $) + +#add library search directory + +#shared +add_library(ccl_atl_ofi SHARED $) +target_include_directories(ccl_atl_ofi PRIVATE ${COMMON_OFI_INC_DIRS}) + 
+target_link_libraries(ccl_atl_ofi PRIVATE pmi) +target_link_libraries(ccl_atl_ofi PRIVATE resizable_pmi) +target_link_libraries(ccl_atl_ofi PRIVATE fabric m) + +if (NOT LIB_ATL_OFI_SO_VERSION AND NOT LIB_ATL_OFI_MAJOR_VERSION) + set_target_properties(ccl_atl_ofi PROPERTIES VERSION 1 SOVERSION 1.0) +else() + set_target_properties(ccl_atl_ofi PROPERTIES VERSION ${LIB_ATL_OFI_SO_VERSION} SOVERSION ${LIB_ATL_OFI_MAJOR_VERSION}) +endif() + +set_target_properties(ccl_atl_ofi PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) +install(TARGETS ccl_atl_ofi LIBRARY DESTINATION ${CCL_INSTALL_LIB}) + +#static +add_library(ccl_atl_ofi-static STATIC $) +set_target_properties(ccl_atl_ofi-static PROPERTIES OUTPUT_NAME ccl_atl_ofi) +set_target_properties(ccl_atl_ofi-static PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) +install(TARGETS ccl_atl_ofi-static ARCHIVE DESTINATION ${CCL_INSTALL_LIB}) diff --git a/src/atl/atl.cpp b/src/atl/atl.cpp index 71776f9db..5b95447d0 100644 --- a/src/atl/atl.cpp +++ b/src/atl/atl.cpp @@ -26,7 +26,7 @@ #define ATL_LIB_PREFIX "libccl_atl_" static int initialized = 0; -static int is_main_addr_reserv = 0; +static int should_reserve_addr = 0; static int atl_lib_filter(const struct dirent* entry) { size_t entry_len = strlen(entry->d_name); @@ -106,8 +106,8 @@ static void atl_ini_dir(const char* transport_name, continue; } - if (is_main_addr_reserv) { - ret = transport.main_addr_reserv(const_cast(main_addr)); + if (should_reserve_addr) { + ret = transport.reserve_addr(const_cast(main_addr)); } else { ret = transport.init(argc, argv, attr, ctx, main_addr); @@ -228,8 +228,8 @@ atl_status_t atl_init(const char* transport_name, return ATL_STATUS_FAILURE; } -void atl_main_addr_reserv(char* main_addr) { - is_main_addr_reserv = 1; +void atl_main_addr_reserve(char* main_addr) { + should_reserve_addr = 1; atl_init("ofi", NULL, NULL, NULL, NULL, main_addr); - is_main_addr_reserv = 0; + should_reserve_addr = 0; } diff --git a/src/atl/atl.h b/src/atl/atl.h 
index 006bf3a48..2fa6a4e93 100644 --- a/src/atl/atl.h +++ b/src/atl/atl.h @@ -29,7 +29,7 @@ atl_status_t atl_init(const char* transport_name, atl_ctx_t** ctx, const char* main_addr); -void atl_main_addr_reserv(char* main_addr); +void atl_main_addr_reserve(char* main_addr); static inline atl_status_t atl_finalize(atl_ctx_t* ctx) { return ctx->ops->finalize(ctx); @@ -70,7 +70,7 @@ static inline atl_status_t atl_mr_dereg(atl_ctx_t* ctx, atl_mr_t* mr) { static inline atl_status_t atl_ep_send(atl_ep_t* ep, const void* buf, size_t len, - size_t dst_proc_idx, + int dst_proc_idx, uint64_t tag, atl_req_t* req) { return ep->p2p_ops->send(ep, buf, len, dst_proc_idx, tag, req); @@ -79,14 +79,14 @@ static inline atl_status_t atl_ep_send(atl_ep_t* ep, static inline atl_status_t atl_ep_recv(atl_ep_t* ep, void* buf, size_t len, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, atl_req_t* req) { return ep->p2p_ops->recv(ep, buf, len, src_proc_idx, tag, req); } static inline atl_status_t atl_ep_probe(atl_ep_t* ep, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, int* found, size_t* recv_len) { @@ -140,7 +140,7 @@ static inline atl_status_t atl_ep_barrier(atl_ep_t* ep, atl_req_t* req) { static inline atl_status_t atl_ep_bcast(atl_ep_t* ep, void* buf, size_t len, - size_t root, + int root, atl_req_t* req) { return ep->coll_ops->bcast(ep, buf, len, root, req); } @@ -149,7 +149,7 @@ static inline atl_status_t atl_ep_reduce(atl_ep_t* ep, const void* send_buf, void* recv_buf, size_t len, - size_t root, + int root, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) { @@ -162,7 +162,7 @@ static inline atl_status_t atl_ep_read(atl_ep_t* ep, atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) { return ep->rma_ops->read(ep, buf, len, mr, addr, remote_key, dst_proc_idx, req); } @@ -173,7 +173,7 @@ static inline atl_status_t atl_ep_write(atl_ep_t* ep, atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t 
dst_proc_idx, + int dst_proc_idx, atl_req_t* req) { return ep->rma_ops->write(ep, buf, len, mr, addr, remote_key, dst_proc_idx, req); } @@ -226,19 +226,19 @@ class iatl { virtual atl_status_t atl_ep_send(atl_ep_t* ep, const void* buf, size_t len, - size_t dst_proc_idx, + int dst_proc_idx, uint64_t tag, atl_req_t* req) = 0; virtual atl_status_t atl_ep_recv(atl_ep_t* ep, void* buf, size_t len, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, atl_req_t* req) = 0; virtual atl_status_t atl_ep_probe(atl_ep_t* ep, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, int* found, size_t* recv_len) = 0; @@ -279,14 +279,14 @@ class iatl { virtual atl_status_t atl_ep_bcast(atl_ep_t* ep, void* buf, size_t len, - size_t root, + int root, atl_req_t* req) = 0; virtual atl_status_t atl_ep_reduce(atl_ep_t* ep, const void* send_buf, void* recv_buf, size_t len, - size_t root, + int root, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) = 0; @@ -305,7 +305,7 @@ class iatl { atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) = 0; virtual atl_status_t atl_ep_write(atl_ep_t* ep, @@ -314,7 +314,7 @@ class iatl { atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) = 0; virtual atl_status_t atl_ep_wait(atl_ep_t* ep, atl_req_t* req) = 0; @@ -326,5 +326,6 @@ class iatl { virtual atl_status_t atl_ep_poll(atl_ep_t* ep) = 0; virtual atl_status_t atl_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* req) = 0; + virtual bool is_inited() = 0; }; #endif diff --git a/src/atl/atl_def.h b/src/atl/atl_def.h index 1d5f5cc9a..b068b7019 100644 --- a/src/atl/atl_def.h +++ b/src/atl/atl_def.h @@ -54,7 +54,7 @@ typedef enum { ATL_PROGRESS_POLL, ATL_PROGRESS_CHECK } atl_progress_mode_t; typedef enum { ATL_RA_WAIT, ATL_RA_RUN, ATL_RA_FINALIZE } atl_resize_action_t; -typedef atl_resize_action_t (*atl_resize_fn_t)(size_t size); +typedef atl_resize_action_t 
(*atl_resize_fn_t)(int size); typedef enum { ATL_STATUS_SUCCESS, @@ -116,10 +116,10 @@ typedef struct { } atl_mr_t; typedef struct { - size_t global_idx; - size_t global_count; - size_t local_idx; - size_t local_count; + int global_idx; + int global_count; + int local_idx; + int local_count; } atl_proc_coord_t; typedef struct { @@ -132,7 +132,7 @@ typedef struct { const char* name; atl_status_t ( *init)(int* argc, char*** argv, atl_attr_t* attr, atl_ctx_t** ctx, const char* main_addr); - atl_status_t (*main_addr_reserv)(char* main_addr); + atl_status_t (*reserve_addr)(char* main_addr); } atl_transport_t; typedef struct { @@ -165,17 +165,13 @@ typedef struct { atl_status_t (*send)(atl_ep_t* ep, const void* buf, size_t len, - size_t dst_proc_idx, - uint64_t tag, - atl_req_t* req); - atl_status_t (*recv)(atl_ep_t* ep, - void* buf, - size_t len, - size_t src_proc_idx, + int dst_proc_idx, uint64_t tag, atl_req_t* req); atl_status_t ( - *probe)(atl_ep_t* ep, size_t src_proc_idx, uint64_t tag, int* found, size_t* recv_len); + *recv)(atl_ep_t* ep, void* buf, size_t len, int src_proc_idx, uint64_t tag, atl_req_t* req); + atl_status_t ( + *probe)(atl_ep_t* ep, int src_proc_idx, uint64_t tag, int* found, size_t* recv_len); } atl_p2p_ops_t; typedef struct { @@ -205,12 +201,12 @@ typedef struct { const int* recv_offsets, atl_req_t* req); atl_status_t (*barrier)(atl_ep_t* ep, atl_req_t* req); - atl_status_t (*bcast)(atl_ep_t* ep, void* buf, size_t len, size_t root, atl_req_t* req); + atl_status_t (*bcast)(atl_ep_t* ep, void* buf, size_t len, int root, atl_req_t* req); atl_status_t (*reduce)(atl_ep_t* ep, const void* send_buf, void* recv_buf, size_t count, - size_t root, + int root, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req); @@ -230,7 +226,7 @@ typedef struct { atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req); atl_status_t (*write)(atl_ep_t* ep, const void* buf, @@ -238,7 +234,7 @@ typedef struct { 
atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req); } atl_rma_ops_t; diff --git a/src/atl/atl_wrapper.cpp b/src/atl/atl_wrapper.cpp index 7f89d24ba..df2878ad8 100644 --- a/src/atl/atl_wrapper.cpp +++ b/src/atl/atl_wrapper.cpp @@ -35,8 +35,7 @@ atl_attr_t atl_wrapper::attr = { 0 /* extra_ep */ }; -void atl_wrapper::set_internal_env(const atl_attr_t& attr) -{ +void atl_wrapper::set_internal_env(const atl_attr_t& attr) { auto transport_type = ccl::global_data::env().atl_transport; if (transport_type == ccl_atl_mpi) @@ -46,12 +45,10 @@ void atl_wrapper::set_internal_env(const atl_attr_t& attr) } atl_wrapper::atl_wrapper() { - auto transport_type = ccl::global_data::env().atl_transport; char* pm_type_str; - switch (transport_type) - { + switch (transport_type) { case ccl_atl_ofi: pm_type_str = getenv(PM_TYPE); if (pm_type_str) { @@ -71,24 +68,18 @@ atl_wrapper::atl_wrapper() { } transport = std::shared_ptr(new atl_ofi()); break; - case ccl_atl_mpi: - transport = std::shared_ptr(new atl_mpi()); - break; - default: - LOG_ERROR("Unsupported yet"); - break; + case ccl_atl_mpi: transport = std::shared_ptr(new atl_mpi()); break; + default: LOG_ERROR("Unsupported yet"); break; } init_transport(); } atl_wrapper::atl_wrapper(std::shared_ptr k) { - auto transport_type = ccl::global_data::env().atl_transport; char* pm_type_str; - switch (transport_type) - { + switch (transport_type) { case ccl_atl_ofi: pm_type_str = getenv(PM_TYPE); if (pm_type_str) { @@ -107,70 +98,67 @@ atl_wrapper::atl_wrapper(std::shared_ptr k) { } transport = std::shared_ptr(new atl_ofi()); break; - case ccl_atl_mpi: - transport = std::shared_ptr(new atl_mpi()); - break; - default: - LOG_ERROR("Unsupported yet"); - break; + case ccl_atl_mpi: transport = std::shared_ptr(new atl_mpi()); break; + default: LOG_ERROR("Unsupported yet"); break; } init_transport(); } -atl_wrapper::atl_wrapper(size_t dev_count, - const std::vector &ranks, 
+atl_wrapper::atl_wrapper(int total_rank_count, + const std::vector& ranks, std::shared_ptr k) { auto transport_type = ccl::global_data::env().atl_transport; - switch (transport_type) - { - case ccl_atl_ofi: - pmi = std::unique_ptr(new pmi_resizable_simple(dev_count, ranks, k)); + switch (transport_type) { + case ccl_atl_ofi: { + size_t transorts_count = transports.size(); + pmi = std::unique_ptr(new pmi_resizable_simple(total_rank_count, ranks, k)); - if (pmi->get_thread() == 0) { + if (pmi->get_local_thread_idx() == 0) { transports.push_back(std::shared_ptr(new atl_ofi())); } - pmi->pmrt_barrier(); + //TODO: Rework it on barrier + while (transorts_count == transports.size()) { + ccl_yield(ccl::global_data::env().yield_type); + } static std::mutex memory_mutex; { std::lock_guard lock(memory_mutex); transport = transports.back(); } - break; - case ccl_atl_mpi: - transport = std::shared_ptr(new atl_mpi()); - break; - default: - LOG_ERROR("Unsupported yet"); - break; + } break; + case ccl_atl_mpi: transport = std::shared_ptr(new atl_mpi()); break; + default: LOG_ERROR("Unsupported yet"); break; } init_transport(); } void atl_wrapper::init_transport() { - LOG_INFO("init ATL, requested ep_count ", attr.ep_count); - - transport->atl_init(nullptr, nullptr, &attr, nullptr, pmi); + static std::mutex memory_mutex; + { + std::lock_guard lock(memory_mutex); + if (!transport->is_inited()) + transport->atl_init(nullptr, nullptr, &attr, nullptr, pmi); + } eps = transport->atl_get_eps(); tag = std::unique_ptr(new ccl_atl_tag(attr.tag_bits, attr.max_tag)); if (pmi) { - threads_count = pmi->get_threads_count(); - devices_per_rank_count = pmi->get_devices_per_rank_count(); + threads_per_process = pmi->get_threads_per_process(); + ranks_per_process = pmi->get_ranks_per_process(); rank = pmi->get_rank(); size = pmi->get_size(); } else { - threads_count = 1; - devices_per_rank_count = 1; - rank = static_cast(transport.get())->get_rank(); - size = 
static_cast(transport.get())->get_size(); + threads_per_process = 1; + ranks_per_process = 1; + rank = static_cast(transport.get())->get_rank(); + size = static_cast(transport.get())->get_size(); } - if (rank == 0) - { + if (rank == 0) { tag->print(); LOG_INFO("\n", diff --git a/src/atl/atl_wrapper.h b/src/atl/atl_wrapper.h index 4886a6de8..6891a53f2 100644 --- a/src/atl/atl_wrapper.h +++ b/src/atl/atl_wrapper.h @@ -28,14 +28,13 @@ class atl_wrapper { public: - static void set_internal_env(const atl_attr_t& attr); ~atl_wrapper(); atl_wrapper(); atl_wrapper(std::shared_ptr k); - atl_wrapper(size_t dev_count, - const std::vector& ranks, + atl_wrapper(int total_rank_count, + const std::vector& ranks, std::shared_ptr k); // atl_status_t @@ -46,11 +45,11 @@ class atl_wrapper { // return transport->atl_init(argc, argv, att, main_addr, pmi); // } - atl_status_t atl_main_addr_reserv(char* main_addr) { + atl_status_t atl_main_addr_reserve(char* main_addr) { if (!pmi) return ATL_STATUS_UNSUPPORTED; - return pmi->pmrt_main_addr_reserv(main_addr); + return pmi->pmrt_main_addr_reserve(main_addr); ; } @@ -98,7 +97,7 @@ class atl_wrapper { atl_status_t atl_ep_send(size_t ep_idx, const void* buf, size_t len, - size_t dst_proc_idx, + int dst_proc_idx, uint64_t tag, atl_req_t* req) { return transport->atl_ep_send(eps[ep_idx], buf, len, dst_proc_idx, tag, req); @@ -107,14 +106,14 @@ class atl_wrapper { atl_status_t atl_ep_recv(size_t ep_idx, void* buf, size_t len, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, atl_req_t* req) { return transport->atl_ep_recv(eps[ep_idx], buf, len, src_proc_idx, tag, req); } atl_status_t atl_ep_probe(size_t ep_idx, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, int* found, size_t* recv_len) { @@ -166,7 +165,7 @@ class atl_wrapper { return transport->atl_ep_barrier(eps[ep_idx], req); } - atl_status_t atl_ep_bcast(size_t ep_idx, void* buf, size_t len, size_t root, atl_req_t* req) { + atl_status_t atl_ep_bcast(size_t ep_idx, void* buf, 
size_t len, int root, atl_req_t* req) { return transport->atl_ep_bcast(eps[ep_idx], buf, len, root, req); } @@ -174,7 +173,7 @@ class atl_wrapper { const void* send_buf, void* recv_buf, size_t len, - size_t root, + int root, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) { @@ -188,7 +187,8 @@ class atl_wrapper { atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) { - return transport->atl_ep_reduce_scatter(eps[ep_idx], send_buf, recv_buf, recv_len, dtype, op, req); + return transport->atl_ep_reduce_scatter( + eps[ep_idx], send_buf, recv_buf, recv_len, dtype, op, req); } atl_status_t atl_ep_read(size_t ep_idx, @@ -197,7 +197,7 @@ class atl_wrapper { atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) { return transport->atl_ep_read( eps[ep_idx], buf, len, mr, addr, remote_key, dst_proc_idx, req); @@ -209,7 +209,7 @@ class atl_wrapper { atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) { return transport->atl_ep_write( eps[ep_idx], buf, len, mr, addr, remote_key, dst_proc_idx, req); @@ -235,22 +235,30 @@ class atl_wrapper { return transport->atl_ep_check(eps[ep_idx], is_completed, req); } - size_t get_threads_count() { - return threads_count; + size_t get_threads_per_process() { + return threads_per_process; } - size_t get_devices_per_rank_count() { - return devices_per_rank_count; + size_t get_ranks_per_process() { + return ranks_per_process; } - size_t get_rank() { + int get_rank() { return rank; } - size_t get_size() { + int get_size() { return size; } + /* + * TODO: Temporary change. 
+ * Need to define correct to unique id + */ + size_t get_id() { + return 0; + } + /* static ATL attr for all transport instances actual values generated by executor */ static atl_attr_t attr; @@ -258,15 +266,15 @@ class atl_wrapper { std::unique_ptr tag; private: + int rank; + int size; + + size_t threads_per_process; + size_t ranks_per_process; std::shared_ptr transport; std::unique_ptr pmi; - - atl_ep_t** eps = nullptr; - size_t threads_count; - size_t devices_per_rank_count; - size_t rank; - size_t size; + void init_transport(); }; diff --git a/src/atl/mpi/CMakeLists.txt b/src/atl/mpi/CMakeLists.txt index decf51c3e..ab733cbed 100644 --- a/src/atl/mpi/CMakeLists.txt +++ b/src/atl/mpi/CMakeLists.txt @@ -13,42 +13,42 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -set(MPI_SRC - atl_mpi.c) - -set(COMMON_MPI_INC_DIRS - ${CMAKE_CURRENT_SOURCE_DIR}/../ - ${PROJECT_SOURCE_DIR}/mpi/include/ - ${PROJECT_SOURCE_DIR}/src/) - -# special library that holds objects only -add_library(ccl_atl_mpi-objects OBJECT ${MPI_SRC}) -set_target_properties(ccl_atl_mpi-objects PROPERTIES POSITION_INDEPENDENT_CODE 1) -target_include_directories(ccl_atl_mpi-objects PRIVATE ${COMMON_MPI_INC_DIRS}) - -# add library search directory -link_directories(${PROJECT_SOURCE_DIR}/mpi/lib) - -# shared -add_library(ccl_atl_mpi SHARED $) -target_include_directories(ccl_atl_mpi PRIVATE ${COMMON_MPI_INC_DIRS}) - -target_link_libraries(ccl_atl_mpi PRIVATE mpi) - -# link with release_mt libmpi.so for oneAPI Base toolkit -set_target_properties(ccl_atl_mpi PROPERTIES LINK_FLAGS "-Wl,-rpath,../../../../mpi/latest/lib/release_mt/") -if (NOT LIB_ATL_MPI_SO_VERSION AND NOT LIB_ATL_MPI_MAJOR_VERSION) - set_target_properties(ccl_atl_mpi PROPERTIES VERSION 1 SOVERSION 1.0) -else() - set_target_properties(ccl_atl_mpi PROPERTIES VERSION ${LIB_ATL_MPI_SO_VERSION} SOVERSION ${LIB_ATL_MPI_MAJOR_VERSION}) -endif() - -set_target_properties(ccl_atl_mpi PROPERTIES 
LIBRARY_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) -install(TARGETS ccl_atl_mpi LIBRARY DESTINATION ${CCL_INSTALL_LIB}) - -# static -add_library(ccl_atl_mpi-static STATIC $) -set_target_properties(ccl_atl_mpi-static PROPERTIES OUTPUT_NAME ccl_atl_mpi) -set_target_properties(ccl_atl_mpi-static PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) -install(TARGETS ccl_atl_mpi-static ARCHIVE DESTINATION ${CCL_INSTALL_LIB}) + +set(MPI_SRC + atl_mpi.c) + +set(COMMON_MPI_INC_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/../ + ${PROJECT_SOURCE_DIR}/mpi/include/ + ${PROJECT_SOURCE_DIR}/src/) + +# special library that holds objects only +add_library(ccl_atl_mpi-objects OBJECT ${MPI_SRC}) +set_target_properties(ccl_atl_mpi-objects PROPERTIES POSITION_INDEPENDENT_CODE 1) +target_include_directories(ccl_atl_mpi-objects PRIVATE ${COMMON_MPI_INC_DIRS}) + +# add library search directory +link_directories(${PROJECT_SOURCE_DIR}/mpi/lib) + +# shared +add_library(ccl_atl_mpi SHARED $) +target_include_directories(ccl_atl_mpi PRIVATE ${COMMON_MPI_INC_DIRS}) + +target_link_libraries(ccl_atl_mpi PRIVATE mpi) + +# link with release_mt libmpi.so for oneAPI Base toolkit +set_target_properties(ccl_atl_mpi PROPERTIES LINK_FLAGS "-Wl,-rpath,../../../../mpi/latest/lib/release_mt/") +if (NOT LIB_ATL_MPI_SO_VERSION AND NOT LIB_ATL_MPI_MAJOR_VERSION) + set_target_properties(ccl_atl_mpi PROPERTIES VERSION 1 SOVERSION 1.0) +else() + set_target_properties(ccl_atl_mpi PROPERTIES VERSION ${LIB_ATL_MPI_SO_VERSION} SOVERSION ${LIB_ATL_MPI_MAJOR_VERSION}) +endif() + +set_target_properties(ccl_atl_mpi PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) +install(TARGETS ccl_atl_mpi LIBRARY DESTINATION ${CCL_INSTALL_LIB}) + +# static +add_library(ccl_atl_mpi-static STATIC $) +set_target_properties(ccl_atl_mpi-static PROPERTIES OUTPUT_NAME ccl_atl_mpi) +set_target_properties(ccl_atl_mpi-static PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) +install(TARGETS ccl_atl_mpi-static ARCHIVE DESTINATION ${CCL_INSTALL_LIB}) diff 
--git a/src/atl/mpi/atl_mpi.c b/src/atl/mpi/atl_mpi.c index 9c530d1aa..65570c124 100644 --- a/src/atl/mpi/atl_mpi.c +++ b/src/atl/mpi/atl_mpi.c @@ -182,9 +182,9 @@ static inline void atl_mpi_check_op_params(void* in_buf, } static void INLINE_TARGET_ATTRIBUTE_ALL atl_mpi_bf16_base_op(void* in, - void* inout, - int* length, - ccl_bf16_reduction_func_ptr op) { + void* inout, + int* length, + ccl_bf16_reduction_func_ptr op) { unsigned short* in_buf = (unsigned short*)in; unsigned short* inout_buf = (unsigned short*)inout; @@ -194,33 +194,33 @@ static void INLINE_TARGET_ATTRIBUTE_ALL atl_mpi_bf16_base_op(void* in, // MPI BF16 operation definitions static void TARGET_ATTRIBUTE_ALL atl_mpi_bf16_sum_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { + void* inout, + int* length, + MPI_Datatype* datatype) { atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); atl_mpi_bf16_base_op(in, inout, length, &sum_wrap); } static void TARGET_ATTRIBUTE_ALL atl_mpi_bf16_prod_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { + void* inout, + int* length, + MPI_Datatype* datatype) { atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); atl_mpi_bf16_base_op(in, inout, length, &prod_wrap); } static void TARGET_ATTRIBUTE_ALL atl_mpi_bf16_min_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { + void* inout, + int* length, + MPI_Datatype* datatype) { atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); atl_mpi_bf16_base_op(in, inout, length, &min_wrap); } static void TARGET_ATTRIBUTE_ALL atl_mpi_bf16_max_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { + void* inout, + int* length, + MPI_Datatype* datatype) { atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); atl_mpi_bf16_base_op(in, inout, length, &max_wrap); } @@ -509,7 +509,6 @@ atl_mpi_lib_type_t atl_mpi_get_lib_type() { } atl_status_t atl_mpi_set_env(const atl_attr_t& attr) { - char 
mpi_ep_count_str[EP_IDX_MAX_STR_LEN] = { 0 }; /* we have endpoints on MPI and ATL levels */ @@ -598,7 +597,7 @@ static atl_status_t atl_mpi_mr_dereg(atl_ctx_t* ctx, atl_mr_t* mr) { static atl_status_t atl_mpi_ep_send(atl_ep_t* ep, const void* buf, size_t len, - size_t dest_proc_idx, + int dst_proc_idx, uint64_t tag, atl_req_t* req) { atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); @@ -606,7 +605,7 @@ static atl_status_t atl_mpi_ep_send(atl_ep_t* ep, mpi_req->comp_state = ATL_MPI_COMP_POSTED; int ret = MPI_Isend( - buf, len, MPI_CHAR, dest_proc_idx, (int)tag, mpi_ep->mpi_comm, &mpi_req->native_req); + buf, len, MPI_CHAR, dst_proc_idx, (int)tag, mpi_ep->mpi_comm, &mpi_req->native_req); #if 0 //#ifdef ENABLE_DEBUG @@ -641,7 +640,7 @@ static atl_status_t atl_mpi_ep_send(atl_ep_t* ep, static atl_status_t atl_mpi_ep_recv(atl_ep_t* ep, void* buf, size_t len, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, atl_req_t* req) { atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); @@ -682,7 +681,7 @@ static atl_status_t atl_mpi_ep_recv(atl_ep_t* ep, } static atl_status_t atl_mpi_ep_probe(atl_ep_t* ep, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, int* found, size_t* recv_len) { @@ -716,8 +715,7 @@ static atl_status_t atl_mpi_ep_allgatherv(atl_ep_t* ep, atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - if (global_data.sync_coll) - { + if (global_data.sync_coll) { ret = MPI_Allgatherv((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, send_len, MPI_CHAR, @@ -729,8 +727,7 @@ static atl_status_t atl_mpi_ep_allgatherv(atl_ep_t* ep, mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; mpi_req->native_req = MPI_REQUEST_NULL; } - else - { + else { ret = MPI_Iallgatherv((send_buf && (send_buf == recv_buf)) ? 
MPI_IN_PLACE : send_buf, send_len, MPI_CHAR, @@ -761,8 +758,7 @@ static atl_status_t atl_mpi_ep_allreduce(atl_ep_t* ep, MPI_Datatype mpi_dtype = atl2mpi_dtype(dtype); MPI_Op mpi_op = atl2mpi_op(op, mpi_dtype); - if (global_data.sync_coll) - { + if (global_data.sync_coll) { ret = MPI_Allreduce((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, recv_buf, count, @@ -772,8 +768,7 @@ static atl_status_t atl_mpi_ep_allreduce(atl_ep_t* ep, mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; mpi_req->native_req = MPI_REQUEST_NULL; } - else - { + else { ret = MPI_Iallreduce((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, recv_buf, count, @@ -825,8 +820,7 @@ static atl_status_t atl_mpi_ep_alltoall(atl_ep_t* ep, atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - if (global_data.sync_coll) - { + if (global_data.sync_coll) { ret = MPI_Alltoall((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, len, MPI_CHAR, @@ -837,8 +831,7 @@ static atl_status_t atl_mpi_ep_alltoall(atl_ep_t* ep, mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; mpi_req->native_req = MPI_REQUEST_NULL; } - else - { + else { ret = MPI_Ialltoall((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, len, MPI_CHAR, @@ -866,8 +859,7 @@ static atl_status_t atl_mpi_ep_alltoallv(atl_ep_t* ep, atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - if (global_data.sync_coll) - { + if (global_data.sync_coll) { ret = MPI_Alltoallv((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, send_lens, send_offsets, @@ -880,8 +872,7 @@ static atl_status_t atl_mpi_ep_alltoallv(atl_ep_t* ep, mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; mpi_req->native_req = MPI_REQUEST_NULL; } - else - { + else { ret = MPI_Ialltoallv((send_buf && (send_buf == recv_buf)) ? 
MPI_IN_PLACE : send_buf, send_lens, send_offsets, @@ -899,20 +890,17 @@ static atl_status_t atl_mpi_ep_alltoallv(atl_ep_t* ep, } static atl_status_t atl_mpi_ep_barrier(atl_ep_t* ep, atl_req_t* req) { - int ret = MPI_SUCCESS; atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - if (global_data.sync_coll) - { + if (global_data.sync_coll) { ret = MPI_Barrier(mpi_ep->mpi_comm); mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; mpi_req->native_req = MPI_REQUEST_NULL; } - else - { + else { ret = MPI_Ibarrier(mpi_ep->mpi_comm, &mpi_req->native_req); mpi_req->comp_state = ATL_MPI_COMP_POSTED; } @@ -923,21 +911,19 @@ static atl_status_t atl_mpi_ep_barrier(atl_ep_t* ep, atl_req_t* req) { static atl_status_t atl_mpi_ep_bcast(atl_ep_t* ep, void* buf, size_t len, - size_t root, + int root, atl_req_t* req) { int ret = MPI_SUCCESS; atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - if (global_data.sync_coll) - { + if (global_data.sync_coll) { ret = MPI_Bcast(buf, len, MPI_CHAR, root, mpi_ep->mpi_comm); mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; mpi_req->native_req = MPI_REQUEST_NULL; } - else - { + else { ret = MPI_Ibcast(buf, len, MPI_CHAR, root, mpi_ep->mpi_comm, &mpi_req->native_req); mpi_req->comp_state = ATL_MPI_COMP_POSTED; } @@ -949,7 +935,7 @@ static atl_status_t atl_mpi_ep_reduce(atl_ep_t* ep, const void* send_buf, void* recv_buf, size_t count, - size_t root, + int root, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) { @@ -958,25 +944,23 @@ static atl_status_t atl_mpi_ep_reduce(atl_ep_t* ep, atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - size_t my_proc_idx = ep->ctx->coord.global_idx; + int my_proc_idx = ep->ctx->coord.global_idx; MPI_Datatype mpi_dtype = atl2mpi_dtype(dtype); MPI_Op mpi_op = atl2mpi_op(op, mpi_dtype); - if (global_data.sync_coll) - { + if 
(global_data.sync_coll) { ret = MPI_Reduce( - (send_buf && (send_buf == recv_buf) && (root == my_proc_idx)) ? MPI_IN_PLACE : send_buf, - recv_buf, - count, - mpi_dtype, - mpi_op, - root, - mpi_ep->mpi_comm); + (send_buf && (send_buf == recv_buf) && (root == my_proc_idx)) ? MPI_IN_PLACE : send_buf, + recv_buf, + count, + mpi_dtype, + mpi_op, + root, + mpi_ep->mpi_comm); mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; mpi_req->native_req = MPI_REQUEST_NULL; } - else - { + else { ret = MPI_Ireduce( (send_buf && (send_buf == recv_buf) && (root == my_proc_idx)) ? MPI_IN_PLACE : send_buf, recv_buf, @@ -1007,20 +991,18 @@ static atl_status_t atl_mpi_ep_reduce_scatter(atl_ep_t* ep, MPI_Datatype mpi_dtype = atl2mpi_dtype(dtype); MPI_Op mpi_op = atl2mpi_op(op, mpi_dtype); - if (global_data.sync_coll) - { - ret = MPI_Reduce_scatter_block( - (send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, - recv_buf, - recv_count, - mpi_dtype, - mpi_op, - mpi_ep->mpi_comm); + if (global_data.sync_coll) { + ret = + MPI_Reduce_scatter_block((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, + recv_buf, + recv_count, + mpi_dtype, + mpi_op, + mpi_ep->mpi_comm); mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; mpi_req->native_req = MPI_REQUEST_NULL; } - else - { + else { ret = MPI_Ireduce_scatter_block( (send_buf && (send_buf == recv_buf)) ? 
MPI_IN_PLACE : send_buf, recv_buf, @@ -1041,7 +1023,7 @@ static atl_status_t atl_mpi_ep_read(atl_ep_t* ep, atl_mr_t* mr, uint64_t addr, uintptr_t r_key, - size_t dest_proc_idx, + int dst_proc_idx, atl_req_t* req) { return ATL_STATUS_UNSUPPORTED; } @@ -1052,7 +1034,7 @@ static atl_status_t atl_mpi_ep_write(atl_ep_t* ep, atl_mr_t* mr, uint64_t addr, uintptr_t r_key, - size_t dest_proc_idx, + int dst_proc_idx, atl_req_t* req) { return ATL_STATUS_UNSUPPORTED; } @@ -1144,7 +1126,6 @@ static atl_comp_ops_t atl_mpi_ep_comp_ops = { .wait = atl_mpi_ep_wait, .check = atl_mpi_ep_check }; static atl_status_t atl_mpi_ep_init(atl_mpi_ctx_t* mpi_ctx, size_t idx, atl_ep_t** ep) { - int ret; ssize_t mpi_ep_idx = idx; @@ -1160,8 +1141,9 @@ static atl_status_t atl_mpi_ep_init(atl_mpi_ctx_t* mpi_ctx, size_t idx, atl_ep_t MPI_Info_create(&info); char mpi_ep_idx_str[EP_IDX_MAX_STR_LEN]; - - if (global_data.extra_ep) mpi_ep_idx += global_data.extra_ep; + + if (global_data.extra_ep) + mpi_ep_idx += global_data.extra_ep; memset(mpi_ep_idx_str, 0, EP_IDX_MAX_STR_LEN); snprintf(mpi_ep_idx_str, EP_IDX_MAX_STR_LEN, "%zu", mpi_ep_idx); @@ -1235,7 +1217,8 @@ static atl_status_t atl_mpi_init(int* argc, ret = MPI_Init_thread(argc, argv, required_thread_level, &provided_thread_level); if (provided_thread_level < required_thread_level) { ATL_MPI_PRINT("unexpected MPI thread level: requested %d, provided %d", - required_thread_level, provided_thread_level); + required_thread_level, + provided_thread_level); goto err_init; } } @@ -1245,8 +1228,9 @@ static atl_status_t atl_mpi_init(int* argc, MPI_Query_thread(&provided_thread_level); if (provided_thread_level < required_thread_level) { ATL_MPI_PRINT("MPI was initialized externaly but with unexpected thread level: " - "requested %d, provided %d", - required_thread_level, provided_thread_level); + "requested %d, provided %d", + required_thread_level, + provided_thread_level); goto err_init; } } @@ -1333,14 +1317,14 @@ static atl_status_t 
atl_mpi_init(int* argc, return ATL_STATUS_FAILURE; } -atl_status_t atl_mpi_main_addr_reserv(char* main_addr) { +atl_status_t atl_mpi_main_addr_reserve(char* main_addr) { return ATL_STATUS_UNSUPPORTED; } ATL_MPI_INI { atl_transport->name = "mpi"; atl_transport->init = atl_mpi_init; - atl_transport->main_addr_reserv = atl_mpi_main_addr_reserv; + atl_transport->reserve_addr = atl_mpi_main_addr_reserve; return ATL_STATUS_SUCCESS; } #ifdef __cplusplus diff --git a/src/atl/mpi/atl_mpi.cpp b/src/atl/mpi/atl_mpi.cpp index 8f23b32e6..29ec6099d 100644 --- a/src/atl/mpi/atl_mpi.cpp +++ b/src/atl/mpi/atl_mpi.cpp @@ -16,8 +16,7 @@ #include "atl_mpi.h" #include "atl_mpi.c" -atl_status_t atl_mpi::atl_set_env(const atl_attr_t& attr) -{ +atl_status_t atl_mpi::atl_set_env(const atl_attr_t& attr) { return atl_mpi_set_env(attr); } @@ -27,6 +26,7 @@ atl_status_t atl_mpi::atl_init(int* argc, const char* main_addr, std::unique_ptr& pmi) { (void)pmi; + inited = true; return atl_mpi_init(argc, argv, attr, &ctx, main_addr); } @@ -63,7 +63,7 @@ atl_status_t atl_mpi::atl_mr_dereg(atl_mr_t* mr) { atl_status_t atl_mpi::atl_ep_send(atl_ep_t* ep, const void* buf, size_t len, - size_t dst_proc_idx, + int dst_proc_idx, uint64_t tag, atl_req_t* req) { return atl_mpi_ep_send(ep, buf, len, dst_proc_idx, tag, req); @@ -72,14 +72,14 @@ atl_status_t atl_mpi::atl_ep_send(atl_ep_t* ep, atl_status_t atl_mpi::atl_ep_recv(atl_ep_t* ep, void* buf, size_t len, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, atl_req_t* req) { return atl_mpi_ep_recv(ep, buf, len, src_proc_idx, tag, req); } atl_status_t atl_mpi::atl_ep_probe(atl_ep_t* ep, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, int* found, size_t* recv_len) { @@ -130,11 +130,7 @@ atl_status_t atl_mpi::atl_ep_barrier(atl_ep_t* ep, atl_req_t* req) { return atl_mpi_ep_barrier(ep, req); } -atl_status_t atl_mpi::atl_ep_bcast(atl_ep_t* ep, - void* buf, - size_t len, - size_t root, - atl_req_t* req) { +atl_status_t 
atl_mpi::atl_ep_bcast(atl_ep_t* ep, void* buf, size_t len, int root, atl_req_t* req) { return atl_mpi_ep_bcast(ep, buf, len, root, req); } @@ -142,7 +138,7 @@ atl_status_t atl_mpi::atl_ep_reduce(atl_ep_t* ep, const void* send_buf, void* recv_buf, size_t len, - size_t root, + int root, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) { @@ -165,7 +161,7 @@ atl_status_t atl_mpi::atl_ep_read(atl_ep_t* ep, atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) { return atl_mpi_ep_read(ep, buf, len, mr, addr, remote_key, dst_proc_idx, req); } @@ -176,7 +172,7 @@ atl_status_t atl_mpi::atl_ep_write(atl_ep_t* ep, atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) { return atl_mpi_ep_write(ep, buf, len, mr, addr, remote_key, dst_proc_idx, req); } diff --git a/src/atl/mpi/atl_mpi.h b/src/atl/mpi/atl_mpi.h index a4941f500..a4aab3a47 100644 --- a/src/atl/mpi/atl_mpi.h +++ b/src/atl/mpi/atl_mpi.h @@ -43,19 +43,19 @@ class atl_mpi final : public iatl { atl_status_t atl_ep_send(atl_ep_t* ep, const void* buf, size_t len, - size_t dst_proc_idx, + int dst_proc_idx, uint64_t tag, atl_req_t* req) override; atl_status_t atl_ep_recv(atl_ep_t* ep, void* buf, size_t len, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, atl_req_t* req) override; atl_status_t atl_ep_probe(atl_ep_t* ep, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, int* found, size_t* recv_len) override; @@ -96,14 +96,14 @@ class atl_mpi final : public iatl { atl_status_t atl_ep_bcast(atl_ep_t* ep, void* buf, size_t len, - size_t root, + int root, atl_req_t* req) override; atl_status_t atl_ep_reduce(atl_ep_t* ep, const void* send_buf, void* recv_buf, size_t len, - size_t root, + int root, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) override; @@ -122,7 +122,7 @@ class atl_mpi final : public iatl { atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + 
int dst_proc_idx, atl_req_t* req) override; atl_status_t atl_ep_write(atl_ep_t* ep, @@ -131,7 +131,7 @@ class atl_mpi final : public iatl { atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) override; atl_status_t atl_ep_wait(atl_ep_t* ep, atl_req_t* req) override; @@ -146,14 +146,18 @@ class atl_mpi final : public iatl { atl_status_t atl_finalize() override; - size_t get_rank() { + int get_rank() { return ctx->coord.global_idx; } - size_t get_size() { + int get_size() { return ctx->coord.global_count; } + bool is_inited() override { + return inited; + } private: atl_ctx_t* ctx = nullptr; bool is_finalized{ false }; + bool inited{ false }; }; diff --git a/src/atl/ofi/atl_ofi.c b/src/atl/ofi/atl_ofi.c index 702fb0817..8ca55b11f 100644 --- a/src/atl/ofi/atl_ofi.c +++ b/src/atl/ofi/atl_ofi.c @@ -196,7 +196,7 @@ typedef struct { /* table[0..proc_count][0..ep_count] */ fi_addr_t* addr_table; size_t addr_len; - size_t first_proc_idx; + int first_proc_idx; } atl_ofi_prov_t; @@ -229,16 +229,14 @@ typedef struct { } atl_ofi_req_t; static void atl_ofi_print_coord(atl_proc_coord_t* coord) { - ATL_OFI_DEBUG_PRINT("coord: global [idx %zu, cnt %zu], local [idx %zu, cnt %zu]", + ATL_OFI_DEBUG_PRINT("coord: global [idx %d, cnt %d], local [idx %d, cnt %d]", coord->global_idx, coord->global_count, coord->local_idx, coord->local_count); } -static inline atl_ofi_prov_t* atl_ofi_get_prov(atl_ep_t* ep, - size_t peer_proc_idx, - size_t msg_size) { +static inline atl_ofi_prov_t* atl_ofi_get_prov(atl_ep_t* ep, int peer_proc_idx, size_t msg_size) { size_t prov_idx; atl_ofi_ctx_t* ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); @@ -251,8 +249,8 @@ static inline atl_ofi_prov_t* atl_ofi_get_prov(atl_ep_t* ep, ofi_ctx->prov_count); atl_proc_coord_t* coord = &(ep->ctx->coord); - size_t my_node_idx = coord->global_idx / coord->local_count; - size_t peer_node_idx = peer_proc_idx / coord->local_count; + int my_node_idx = 
coord->global_idx / coord->local_count; + int peer_node_idx = peer_proc_idx / coord->local_count; if ((my_node_idx == peer_node_idx) && (msg_size <= ofi_ctx->provs[ofi_ctx->shm_prov_idx].max_msg_size)) @@ -273,7 +271,7 @@ static inline atl_ofi_prov_t* atl_ofi_get_prov(atl_ep_t* ep, static inline fi_addr_t atl_ofi_get_addr(atl_ctx_t* ctx, atl_ofi_prov_t* prov, - size_t proc_idx, + int proc_idx, size_t ep_idx) { return *(prov->addr_table + ((ctx->ep_count * (proc_idx - prov->first_proc_idx)) + ep_idx)); } @@ -284,12 +282,12 @@ static atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, ipmi* p atl_proc_coord_t* coord = &(ofi_ctx->ctx.coord); atl_status_t ret = ATL_STATUS_SUCCESS; - size_t i; - size_t local_idx = 0, local_count = 0; + int i; + int local_idx = 0, local_count = 0; char* all_hostnames = NULL; char my_hostname[ATL_OFI_MAX_HOSTNAME_LEN] = { 0 }; size_t my_hostname_len = 0; - size_t my_global_proc_idx = coord->global_idx; + int my_global_proc_idx = coord->global_idx; gethostname(my_hostname, ATL_OFI_MAX_HOSTNAME_LEN - 1); my_hostname_len = strlen(my_hostname); @@ -306,7 +304,7 @@ static atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, ipmi* p snprintf(my_hostname + my_hostname_len, ATL_OFI_MAX_HOSTNAME_LEN - my_hostname_len, - "-%zu", + "-%d", my_global_proc_idx); ret = pmi->pmrt_kvs_put((char*)ATL_OFI_HOSTNAME_PM_KEY, @@ -343,9 +341,9 @@ static atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, ipmi* p all_hostnames + i * ATL_OFI_MAX_HOSTNAME_LEN, my_hostname_len + 1 /* including "-" at the end */)) { local_count++; - size_t peer_global_proc_idx; + int peer_global_proc_idx; sscanf(all_hostnames + i * ATL_OFI_MAX_HOSTNAME_LEN + my_hostname_len + 1, - "%zu", + "%d", &peer_global_proc_idx); if (my_global_proc_idx > peer_global_proc_idx) local_idx++; @@ -372,8 +370,9 @@ static atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, atl_ctx_t* ctx = &(ofi_ctx->ctx); atl_ofi_prov_t* prov = 
&(ofi_ctx->provs[prov_idx]); - atl_status_t ret; - size_t i, j; + atl_status_t ret = ATL_STATUS_SUCCESS; + int i; + size_t j; int insert_count; size_t addr_idx = 0; @@ -382,20 +381,20 @@ static atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, size_t named_ep_count = (prov->sep ? 1 : ctx->ep_count); - size_t local_count = ctx->coord.local_count; - size_t node_idx = ctx->coord.global_idx / local_count; - size_t shm_start_idx = node_idx * local_count; - size_t shm_end_idx = (node_idx + 1) * local_count; + int local_count = ctx->coord.local_count; + int node_idx = ctx->coord.global_idx / local_count; + int shm_start_idx = node_idx * local_count; + int shm_end_idx = (node_idx + 1) * local_count; - ATL_OFI_DEBUG_PRINT("shm_start_idx %zu, shm_end_idx %zu", shm_start_idx, shm_end_idx); + ATL_OFI_DEBUG_PRINT("shm_start_idx %d, shm_end_idx %d", shm_start_idx, shm_end_idx); - size_t proc_count = prov->is_shm ? ctx->coord.local_count : ctx->coord.global_count; + int proc_count = prov->is_shm ? 
ctx->coord.local_count : ctx->coord.global_count; if (proc_count == 0) return ATL_STATUS_SUCCESS; ATL_OFI_DEBUG_PRINT( - "name %s, is_shm %d, addr_len %zu, local_count %zu, global_count %zu, proc_count %zu", + "name %s, is_shm %d, addr_len %zu, local_count %d, global_count %d, proc_count %d", prov->info->fabric_attr->prov_name, prov->is_shm, prov->addr_len, @@ -407,7 +406,7 @@ static atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, epnames_table_len = prov->addr_len * named_ep_count * proc_count; if (epnames_table_len == 0) { - ATL_OFI_PRINT("epnames_table_len == 0, addr_len %zu, named_ep_count %zu, proc_count %zu", + ATL_OFI_PRINT("epnames_table_len == 0, addr_len %zu, named_ep_count %zu, proc_count %d", prov->addr_len, named_ep_count, proc_count); @@ -438,7 +437,7 @@ static atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, prov->addr_len); if (ret) { - ATL_OFI_PRINT("kvs_get error: ret %d, proc_idx %zu, ep_idx %zu, addr_idx %zu", + ATL_OFI_PRINT("kvs_get error: ret %d, proc_idx %d, ep_idx %zu, addr_idx %zu", ret, i, j, @@ -451,7 +450,7 @@ static atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, } ATL_OFI_DEBUG_PRINT( - "kvs_get: ep_count %zu, proc_count %zu, got %zu", named_ep_count, proc_count, addr_idx); + "kvs_get: ep_count %zu, proc_count %d, got %zu", named_ep_count, proc_count, addr_idx); if (addr_idx != named_ep_count * proc_count) { ATL_OFI_PRINT("unexpected kvs_get results: expected %zu, got %zu", @@ -474,7 +473,7 @@ static atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, insert_count = fi_av_insert( prov->av, epnames_table, named_ep_count * proc_count, prov->addr_table, 0, NULL); - ATL_OFI_DEBUG_PRINT("av_insert: ep_count %zu, proc_count %zu, inserted %d", + ATL_OFI_DEBUG_PRINT("av_insert: ep_count %zu, proc_count %d, inserted %d", named_ep_count, proc_count, insert_count); @@ -495,6 +494,11 @@ static atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, fi_addr_t* 
table; table = (fi_addr_t*)calloc(1, proc_count * sizeof(fi_addr_t)); + if (table == NULL) { + ATL_OFI_DEBUG_PRINT("Memory allocaion failed"); + ret = ATL_STATUS_FAILURE; + goto err_ep_names; + } memcpy(table, prov->addr_table, proc_count * sizeof(fi_addr_t)); for (i = 0; i < proc_count; i++) { @@ -736,7 +740,6 @@ static int atl_ofi_wait_cancel_cq(struct fid_cq* cq) { } static atl_status_t atl_ofi_prov_ep_init(atl_ofi_prov_t* prov, size_t ep_idx) { - ssize_t ret = 0; struct fi_cq_attr cq_attr; @@ -903,13 +906,16 @@ static atl_status_t atl_ofi_set_env(const atl_attr_t& attr) { } static atl_status_t atl_ofi_adjust_env(atl_ofi_ctx_t* ofi_ctx, const atl_attr_t& attr) { - atl_ofi_set_env(attr); char* prov_env = getenv("FI_PROVIDER"); if (prov_env && strlen(prov_env)) { ofi_ctx->prov_env_copy = (char*)calloc(strlen(prov_env) + 1, sizeof(char)); + if (ofi_ctx->prov_env_copy == NULL) { + ATL_OFI_DEBUG_PRINT("Memory allocaion failed"); + return ATL_STATUS_FAILURE; + } memcpy(ofi_ctx->prov_env_copy, prov_env, strlen(prov_env)); } else @@ -925,8 +931,11 @@ static atl_status_t atl_ofi_adjust_env(atl_ofi_ctx_t* ofi_ctx, const atl_attr_t& (single_prov ? 
0 : 1) + /* for delimeter */ 1; /* for terminating null symbol */ - char* prov_env_copy; - prov_env_copy = (char*)calloc(prov_env_copy_size, sizeof(char)); + char* prov_env_copy = (char*)calloc(prov_env_copy_size, sizeof(char)); + if (prov_env_copy == NULL) { + ATL_OFI_DEBUG_PRINT("Memory allocaion failed"); + return ATL_STATUS_FAILURE; + } if (single_prov) snprintf(prov_env_copy, prov_env_copy_size, "%s", ATL_OFI_SHM_PROV_NAME); @@ -1019,7 +1028,7 @@ static atl_status_t atl_ofi_ep_wait(atl_ep_t* ep, atl_req_t* req); static atl_status_t atl_ofi_ep_send(atl_ep_t* ep, const void* buf, size_t len, - size_t dst_proc_idx, + int dst_proc_idx, uint64_t tag, atl_req_t* req) { ssize_t ret; @@ -1054,7 +1063,7 @@ static atl_status_t atl_ofi_ep_send(atl_ep_t* ep, static atl_status_t atl_ofi_ep_recv(atl_ep_t* ep, void* buf, size_t len, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, atl_req_t* req) { ssize_t ret; @@ -1088,7 +1097,7 @@ static atl_status_t atl_ofi_ep_recv(atl_ep_t* ep, } static atl_status_t atl_ofi_ep_probe(atl_ep_t* ep, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, int* found, size_t* recv_len) { @@ -1242,7 +1251,7 @@ static atl_status_t atl_ofi_ep_barrier(atl_ep_t* ep, atl_req_t* req) { static atl_status_t atl_ofi_ep_bcast(atl_ep_t* ep, void* buf, size_t len, - size_t root, + int root, atl_req_t* req) { return ATL_STATUS_UNSUPPORTED; } @@ -1251,7 +1260,7 @@ static atl_status_t atl_ofi_ep_reduce(atl_ep_t* ep, const void* send_buf, void* recv_buf, size_t count, - size_t root, + int root, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) { @@ -1274,7 +1283,7 @@ static atl_status_t atl_ofi_ep_read(atl_ep_t* ep, atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) { ssize_t ret; @@ -1312,7 +1321,7 @@ static atl_status_t atl_ofi_ep_write(atl_ep_t* ep, atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) { ssize_t ret; @@ 
-1570,7 +1579,7 @@ static atl_status_t atl_ofi_init(int* argc, if (prov_env && !strcmp(prov_env, ATL_OFI_SHM_PROV_NAME)) { ATL_OFI_ASSERT( coord->global_count == coord->local_count, - "shm provider is requested as primary provider but global_count (%zu) != local_count (%zu)", + "shm provider is requested as primary provider but global_count (%d) != local_count (%d)", coord->global_count, coord->local_count); @@ -1886,7 +1895,7 @@ static atl_status_t atl_ofi_init(int* argc, ATL_OFI_INI { atl_transport->name = "ofi"; atl_transport->init = atl_ofi_init; - atl_transport->main_addr_reserv = atl_ofi_main_addr_reserv; + atl_transport->reserve_addr = atl_ofi_main_addr_reserve; return ATL_STATUS_SUCCESS; } #endif diff --git a/src/atl/ofi/atl_ofi.cpp b/src/atl/ofi/atl_ofi.cpp index 1a64d0dc6..29a616550 100644 --- a/src/atl/ofi/atl_ofi.cpp +++ b/src/atl/ofi/atl_ofi.cpp @@ -16,8 +16,7 @@ #include "atl_ofi.h" #include "atl_ofi.c" -atl_status_t atl_ofi::atl_set_env(const atl_attr_t& attr) -{ +atl_status_t atl_ofi::atl_set_env(const atl_attr_t& attr) { return atl_ofi_set_env(attr); } @@ -26,6 +25,7 @@ atl_status_t atl_ofi::atl_init(int* argc, atl_attr_t* attr, const char* main_addr, std::unique_ptr& pmi) { + inited = true; return atl_ofi_init(argc, argv, attr, &ctx, main_addr, pmi.get()); } @@ -62,7 +62,7 @@ atl_status_t atl_ofi::atl_update(std::unique_ptr& pmi) { if (ofi_ctx->prov_count == 1 && ofi_ctx->provs[0].is_shm) { ATL_OFI_ASSERT(coord->global_count == coord->local_count, - "unexpected coord after update: global_count %zu, local_count %zu", + "unexpected coord after update: global_count %d, local_count %d", coord->global_count, coord->local_count); /* TODO: recreate providers */ @@ -104,7 +104,7 @@ atl_status_t atl_ofi::atl_mr_dereg(atl_mr_t* mr) { atl_status_t atl_ofi::atl_ep_send(atl_ep_t* ep, const void* buf, size_t len, - size_t dst_proc_idx, + int dst_proc_idx, uint64_t tag, atl_req_t* req) { return atl_ofi_ep_send(ep, buf, len, dst_proc_idx, tag, req); @@ -113,14 
+113,14 @@ atl_status_t atl_ofi::atl_ep_send(atl_ep_t* ep, atl_status_t atl_ofi::atl_ep_recv(atl_ep_t* ep, void* buf, size_t len, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, atl_req_t* req) { return atl_ofi_ep_recv(ep, buf, len, src_proc_idx, tag, req); } atl_status_t atl_ofi::atl_ep_probe(atl_ep_t* ep, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, int* found, size_t* recv_len) { @@ -171,11 +171,7 @@ atl_status_t atl_ofi::atl_ep_barrier(atl_ep_t* ep, atl_req_t* req) { return atl_ofi_ep_barrier(ep, req); } -atl_status_t atl_ofi::atl_ep_bcast(atl_ep_t* ep, - void* buf, - size_t len, - size_t root, - atl_req_t* req) { +atl_status_t atl_ofi::atl_ep_bcast(atl_ep_t* ep, void* buf, size_t len, int root, atl_req_t* req) { return atl_ofi_ep_bcast(ep, buf, len, root, req); } @@ -183,7 +179,7 @@ atl_status_t atl_ofi::atl_ep_reduce(atl_ep_t* ep, const void* send_buf, void* recv_buf, size_t len, - size_t root, + int root, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) { @@ -206,7 +202,7 @@ atl_status_t atl_ofi::atl_ep_read(atl_ep_t* ep, atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) { return atl_ofi_ep_read(ep, buf, len, mr, addr, remote_key, dst_proc_idx, req); } @@ -217,7 +213,7 @@ atl_status_t atl_ofi::atl_ep_write(atl_ep_t* ep, atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) { return atl_ofi_ep_write(ep, buf, len, mr, addr, remote_key, dst_proc_idx, req); } diff --git a/src/atl/ofi/atl_ofi.h b/src/atl/ofi/atl_ofi.h index bba00f10b..fd093ceb7 100644 --- a/src/atl/ofi/atl_ofi.h +++ b/src/atl/ofi/atl_ofi.h @@ -46,19 +46,19 @@ class atl_ofi final : public iatl { atl_status_t atl_ep_send(atl_ep_t* ep, const void* buf, size_t len, - size_t dst_proc_idx, + int dst_proc_idx, uint64_t tag, atl_req_t* req) override; atl_status_t atl_ep_recv(atl_ep_t* ep, void* buf, size_t len, - size_t src_proc_idx, + int src_proc_idx, 
uint64_t tag, atl_req_t* req) override; atl_status_t atl_ep_probe(atl_ep_t* ep, - size_t src_proc_idx, + int src_proc_idx, uint64_t tag, int* found, size_t* recv_len) override; @@ -99,14 +99,14 @@ class atl_ofi final : public iatl { atl_status_t atl_ep_bcast(atl_ep_t* ep, void* buf, size_t len, - size_t root, + int root, atl_req_t* req) override; atl_status_t atl_ep_reduce(atl_ep_t* ep, const void* send_buf, void* recv_buf, size_t len, - size_t root, + int root, atl_datatype_t dtype, atl_reduction_t op, atl_req_t* req) override; @@ -125,7 +125,7 @@ class atl_ofi final : public iatl { atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) override; atl_status_t atl_ep_write(atl_ep_t* ep, @@ -134,7 +134,7 @@ class atl_ofi final : public iatl { atl_mr_t* mr, uint64_t addr, uintptr_t remote_key, - size_t dst_proc_idx, + int dst_proc_idx, atl_req_t* req) override; atl_status_t atl_ep_wait(atl_ep_t* ep, atl_req_t* req) override; @@ -149,7 +149,12 @@ class atl_ofi final : public iatl { atl_status_t atl_finalize() override; + bool is_inited() override { + return inited; + } + private: atl_ctx_t* ctx = nullptr; bool is_finalized{ false }; + bool inited{ false }; }; diff --git a/src/atl/util/pm/pm_rt.h b/src/atl/util/pm/pm_rt.h index 6f4f8facc..8057df881 100644 --- a/src/atl/util/pm/pm_rt.h +++ b/src/atl/util/pm/pm_rt.h @@ -13,180 +13,180 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifndef PM_RT_H -#define PM_RT_H - -#include "atl_def.h" - -#define PM_TYPE "CCL_PM_TYPE" - -#define PM_RT_VAL_SIMPLE "simple" -#define PM_RT_VAL_RESIZABLE "resizable" - -typedef struct pm_rt_desc pm_rt_desc_t; - -//typedef enum pm_rt_type { -// PM_RT_SIMPLE = 0, -// PM_RT_RESIZABLE = 1, -//} pm_rt_type_t; -// -//static pm_rt_type_t type = PM_RT_SIMPLE; - -typedef struct pm_rt_ops { - void (*finalize)(pm_rt_desc_t *pmrt_desc); - void (*barrier)(pm_rt_desc_t *pmrt_desc); - atl_status_t (*update)(size_t *proc_idx, size_t *proc_count); - atl_status_t (*wait_notification)(void); -} pm_rt_ops_t; - -typedef struct pm_rt_kvs_ops { - atl_status_t (*put)(pm_rt_desc_t *pmrt_desc, - char *kvs_key, - size_t proc_idx, - const void *kvs_val, - size_t kvs_val_len); - atl_status_t (*get)(pm_rt_desc_t *pmrt_desc, - char *kvs_key, - size_t proc_idx, - void *kvs_val, - size_t kvs_val_len); -} pm_rt_kvs_ops_t; - -struct pm_rt_desc { - pm_rt_ops_t *ops; - pm_rt_kvs_ops_t *kvs_ops; -}; - -#if 0 -/* PMI RT */ -atl_status_t pmirt_init(size_t *proc_idx, size_t *procs_num, pm_rt_desc_t **pmrt_desc); -atl_status_t resizable_pmirt_init(size_t *proc_idx, - size_t *proc_count, - pm_rt_desc_t **pmrt_desc, - const char *main_addr); -atl_status_t resizable_pmirt_set_resize_function(atl_resize_fn_t resize_fn); -atl_status_t resizable_pmirt_main_addr_reserv(char *main_addr); - - -static inline int is_pm_resize_enabled() { - if (type == PM_RT_RESIZABLE) - return 1; - return 0; -} - -static inline atl_status_t pmrt_init(size_t *proc_idx, - size_t *procs_num, - pm_rt_desc_t **pmrt_desc, - const char *main_addr) { - char *type_str = getenv(PM_TYPE); - - if (type_str) { - if (strstr(type_str, PM_RT_VAL_SIMPLE)) { - type = PM_RT_SIMPLE; - } - else if (strstr(type_str, PM_RT_VAL_RESIZABLE)) { - type = PM_RT_RESIZABLE; - } - else { - printf("Unknown %s: %s\n", PM_TYPE, type_str); - return ATL_STATUS_FAILURE; - } - } - - switch (type) { - case PM_RT_SIMPLE: return pmirt_init(proc_idx, procs_num, 
pmrt_desc); - case PM_RT_RESIZABLE: - return resizable_pmirt_init(proc_idx, procs_num, pmrt_desc, main_addr); - default: printf("Wrong CCL_PM_TYPE: %s", type_str); return ATL_STATUS_FAILURE; - } -} - -static inline atl_status_t pmrt_main_addr_reserv(char *main_addr) { - return resizable_pmirt_main_addr_reserv(main_addr); -} - -static inline atl_status_t pmrt_set_resize_function(atl_resize_fn_t user_checker) { - switch (type) { - case PM_RT_RESIZABLE: return resizable_pmirt_set_resize_function(user_checker); - default: return ATL_STATUS_SUCCESS; - } -} -static inline atl_status_t pmrt_update(size_t *proc_idx, - size_t *proc_count, - pm_rt_desc_t *pmrt_desc) { - return pmrt_desc->ops->update(proc_idx, proc_count); -} -static inline atl_status_t pmrt_wait_notification(pm_rt_desc_t *pmrt_desc) { - return pmrt_desc->ops->wait_notification(); -} -static inline void pmrt_finalize(pm_rt_desc_t *pmrt_desc) { - pmrt_desc->ops->finalize(pmrt_desc); -} -static inline void pmrt_barrier(pm_rt_desc_t *pmrt_desc) { - pmrt_desc->ops->barrier(pmrt_desc); -} - -static inline atl_status_t pmrt_kvs_put(pm_rt_desc_t *pmrt_desc, - char *kvs_key, - size_t proc_idx, - const void *kvs_val, - size_t kvs_val_len) { - return pmrt_desc->kvs_ops->put(pmrt_desc, kvs_key, proc_idx, kvs_val, kvs_val_len); -} - -static inline atl_status_t pmrt_kvs_get(pm_rt_desc_t *pmrt_desc, - char *kvs_key, - size_t proc_idx, - void *kvs_val, - size_t kvs_val_len) { - return pmrt_desc->kvs_ops->get(pmrt_desc, kvs_key, proc_idx, kvs_val, kvs_val_len); -} - -} -#endif - -#ifdef __cplusplus -class ipmi { -public: - virtual ~ipmi() = default; - - virtual int is_pm_resize_enabled() = 0; - - virtual atl_status_t pmrt_main_addr_reserv(char *main_addr) = 0; - - virtual atl_status_t pmrt_set_resize_function(atl_resize_fn_t resize_fn) = 0; - - virtual atl_status_t pmrt_update() = 0; - - virtual atl_status_t pmrt_wait_notification() = 0; - - virtual void pmrt_finalize() = 0; - - virtual void pmrt_barrier() = 0; - - virtual 
atl_status_t pmrt_kvs_put(char *kvs_key, - size_t proc_idx, - const void *kvs_val, - size_t kvs_val_len) = 0; - - virtual atl_status_t pmrt_kvs_get(char *kvs_key, - size_t proc_idx, - void *kvs_val, - size_t kvs_val_len) = 0; - - virtual size_t get_rank() = 0; - - virtual size_t get_size() = 0; - - virtual size_t get_thread() = 0; - - virtual size_t get_local_kvs_id() = 0; - - virtual void set_local_kvs_id(size_t local_kvs_id) = 0; - - virtual size_t get_threads_count() = 0; - - virtual size_t get_devices_per_rank_count() = 0; -}; -#endif -#endif /* PM_RT_H */ +#ifndef PM_RT_H +#define PM_RT_H + +#include "atl_def.h" + +#define PM_TYPE "CCL_PM_TYPE" + +#define PM_RT_VAL_SIMPLE "simple" +#define PM_RT_VAL_RESIZABLE "resizable" + +typedef struct pm_rt_desc pm_rt_desc_t; + +//typedef enum pm_rt_type { +// PM_RT_SIMPLE = 0, +// PM_RT_RESIZABLE = 1, +//} pm_rt_type_t; +// +//static pm_rt_type_t type = PM_RT_SIMPLE; + +typedef struct pm_rt_ops { + void (*finalize)(pm_rt_desc_t *pmrt_desc); + void (*barrier)(pm_rt_desc_t *pmrt_desc); + atl_status_t (*update)(int *proc_idx, int *proc_count); + atl_status_t (*wait_notification)(void); +} pm_rt_ops_t; + +typedef struct pm_rt_kvs_ops { + atl_status_t (*put)(pm_rt_desc_t *pmrt_desc, + char *kvs_key, + int proc_idx, + const void *kvs_val, + size_t kvs_val_len); + atl_status_t (*get)(pm_rt_desc_t *pmrt_desc, + char *kvs_key, + int proc_idx, + void *kvs_val, + size_t kvs_val_len); +} pm_rt_kvs_ops_t; + +struct pm_rt_desc { + pm_rt_ops_t *ops; + pm_rt_kvs_ops_t *kvs_ops; +}; + +#if 0 +/* PMI RT */ +atl_status_t pmirt_init(int *proc_idx, int *procs_num, pm_rt_desc_t **pmrt_desc); +atl_status_t resizable_pmirt_init(int *proc_idx, + int *proc_count, + pm_rt_desc_t **pmrt_desc, + const char *main_addr); +atl_status_t resizable_pmirt_set_resize_function(atl_resize_fn_t resize_fn); +atl_status_t resizable_pmirt_main_addr_reserve(char *main_addr); + + +static inline int is_pm_resize_enabled() { + if (type == PM_RT_RESIZABLE) + return 1; 
+ return 0; +} + +static inline atl_status_t pmrt_init(int *proc_idx, + int *procs_num, + pm_rt_desc_t **pmrt_desc, + const char *main_addr) { + char *type_str = getenv(PM_TYPE); + + if (type_str) { + if (strstr(type_str, PM_RT_VAL_SIMPLE)) { + type = PM_RT_SIMPLE; + } + else if (strstr(type_str, PM_RT_VAL_RESIZABLE)) { + type = PM_RT_RESIZABLE; + } + else { + printf("Unknown %s: %s\n", PM_TYPE, type_str); + return ATL_STATUS_FAILURE; + } + } + + switch (type) { + case PM_RT_SIMPLE: return pmirt_init(proc_idx, procs_num, pmrt_desc); + case PM_RT_RESIZABLE: + return resizable_pmirt_init(proc_idx, procs_num, pmrt_desc, main_addr); + default: printf("Wrong CCL_PM_TYPE: %s", type_str); return ATL_STATUS_FAILURE; + } +} + +static inline atl_status_t pmrt_main_addr_reserve(char *main_addr) { + return resizable_pmirt_main_addr_reserve(main_addr); +} + +static inline atl_status_t pmrt_set_resize_function(atl_resize_fn_t user_checker) { + switch (type) { + case PM_RT_RESIZABLE: return resizable_pmirt_set_resize_function(user_checker); + default: return ATL_STATUS_SUCCESS; + } +} +static inline atl_status_t pmrt_update(int *proc_idx, + int *proc_count, + pm_rt_desc_t *pmrt_desc) { + return pmrt_desc->ops->update(proc_idx, proc_count); +} +static inline atl_status_t pmrt_wait_notification(pm_rt_desc_t *pmrt_desc) { + return pmrt_desc->ops->wait_notification(); +} +static inline void pmrt_finalize(pm_rt_desc_t *pmrt_desc) { + pmrt_desc->ops->finalize(pmrt_desc); +} +static inline void pmrt_barrier(pm_rt_desc_t *pmrt_desc) { + pmrt_desc->ops->barrier(pmrt_desc); +} + +static inline atl_status_t pmrt_kvs_put(pm_rt_desc_t *pmrt_desc, + char *kvs_key, + int proc_idx, + const void *kvs_val, + size_t kvs_val_len) { + return pmrt_desc->kvs_ops->put(pmrt_desc, kvs_key, proc_idx, kvs_val, kvs_val_len); +} + +static inline atl_status_t pmrt_kvs_get(pm_rt_desc_t *pmrt_desc, + char *kvs_key, + int proc_idx, + void *kvs_val, + size_t kvs_val_len) { + return 
pmrt_desc->kvs_ops->get(pmrt_desc, kvs_key, proc_idx, kvs_val, kvs_val_len); +} + +} +#endif + +#ifdef __cplusplus +class ipmi { +public: + virtual ~ipmi() = default; + + virtual int is_pm_resize_enabled() = 0; + + virtual atl_status_t pmrt_main_addr_reserve(char *main_addr) = 0; + + virtual atl_status_t pmrt_set_resize_function(atl_resize_fn_t resize_fn) = 0; + + virtual atl_status_t pmrt_update() = 0; + + virtual atl_status_t pmrt_wait_notification() = 0; + + virtual void pmrt_finalize() = 0; + + virtual void pmrt_barrier() = 0; + + virtual atl_status_t pmrt_kvs_put(char *kvs_key, + int proc_idx, + const void *kvs_val, + size_t kvs_val_len) = 0; + + virtual atl_status_t pmrt_kvs_get(char *kvs_key, + int proc_idx, + void *kvs_val, + size_t kvs_val_len) = 0; + + virtual int get_rank() = 0; + + virtual int get_size() = 0; + + virtual size_t get_local_thread_idx() = 0; + + virtual size_t get_local_kvs_id() = 0; + + virtual void set_local_kvs_id(size_t local_kvs_id) = 0; + + virtual size_t get_threads_per_process() = 0; + + virtual size_t get_ranks_per_process() = 0; +}; +#endif +#endif /* PM_RT_H */ diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.cpp index 3083d1bf8..ec1f9ba30 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.cpp @@ -20,7 +20,7 @@ #include "util/pm/codec/pm_rt_codec.h" #include "pmi_resizable.h" -#define RESIZABLE_PMI_RT_KEY_FORMAT "%s-%zu" +#define RESIZABLE_PMI_RT_KEY_FORMAT "%s-%d" int pmi_resizable::is_pm_resize_enabled() { return true; @@ -85,8 +85,8 @@ atl_status_t pmi_resizable::pmrt_init(const char *main_addr) { return ATL_STATUS_FAILURE; } -atl_status_t pmi_resizable::pmrt_main_addr_reserv(char *main_addr) { - int ret = PMIR_Main_Addr_Reserv(main_addr); +atl_status_t pmi_resizable::pmrt_main_addr_reserve(char *main_addr) { + int ret = PMIR_Main_Addr_Reserve(main_addr); if (ret) return ATL_STATUS_FAILURE; @@ 
-155,7 +155,7 @@ void pmi_resizable::pmrt_barrier() { } atl_status_t pmi_resizable::pmrt_kvs_put(char *kvs_key, - size_t proc_idx, + int proc_idx, const void *kvs_val, size_t kvs_val_len) { int ret; @@ -186,7 +186,7 @@ atl_status_t pmi_resizable::pmrt_kvs_put(char *kvs_key, } atl_status_t pmi_resizable::pmrt_kvs_get(char *kvs_key, - size_t proc_idx, + int proc_idx, void *kvs_val, size_t kvs_val_len) { int ret; diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.h index 9a518f31d..8ec4023d0 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.h @@ -18,7 +18,7 @@ #include "atl/atl_def.h" #include "atl/util/pm/pm_rt.h" #include "atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h" -#include "atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.h" +#include "atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp" #include "atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.hpp" #define PMIR_SUCCESS 0 @@ -42,7 +42,7 @@ typedef enum { KVS_RA_RUN = 1, KVS_RA_FINALIZE = 2, } kvs_resize_action_t; -typedef kvs_resize_action_t (*pmir_resize_fn_t)(size_t comm_size); +typedef kvs_resize_action_t (*pmir_resize_fn_t)(int comm_size); class helper; class pmi_resizable final : public ipmi { @@ -58,7 +58,7 @@ class pmi_resizable final : public ipmi { int is_pm_resize_enabled() override; - atl_status_t pmrt_main_addr_reserv(char* main_addr) override; + atl_status_t pmrt_main_addr_reserve(char* main_addr) override; atl_status_t pmrt_set_resize_function(atl_resize_fn_t resize_fn) override; @@ -69,32 +69,32 @@ class pmi_resizable final : public ipmi { void pmrt_barrier() override; atl_status_t pmrt_kvs_put(char* kvs_key, - size_t proc_idx, + int proc_idx, const void* kvs_val, size_t kvs_val_len) override; atl_status_t pmrt_kvs_get(char* kvs_key, - size_t proc_idx, + int proc_idx, void* kvs_val, size_t kvs_val_len) override; void Hard_finilize(int sig); - 
size_t get_rank() override; + int get_rank() override; - size_t get_size() override; + int get_size() override; - size_t get_thread() override; + size_t get_local_thread_idx() override; size_t get_local_kvs_id() override; void set_local_kvs_id(size_t local_kvs_id) override; - size_t get_threads_count() override { + size_t get_threads_per_process() override { return 1; } - size_t get_devices_per_rank_count() override { + size_t get_ranks_per_process() override { return 1; } void pmrt_finalize() override; @@ -103,15 +103,15 @@ class pmi_resizable final : public ipmi { bool is_finalized{ false }; atl_status_t pmrt_init(const char* main_addr = nullptr); /*Was in API ->*/ - int PMIR_Main_Addr_Reserv(char* main_addr); + int PMIR_Main_Addr_Reserve(char* main_addr); int PMIR_Init(const char* main_addr); int PMIR_Finalize(void); - int PMIR_Get_size(size_t* size); + int PMIR_Get_size(int* size); - int PMIR_Get_rank(size_t* rank); + int PMIR_Get_rank(int* rank); int PMIR_KVS_Get_my_name(char* kvs_name, size_t length); @@ -135,10 +135,11 @@ class pmi_resizable final : public ipmi { int PMIR_Wait_notification(void); /* <- Was in API*/ - kvs_resize_action_t default_checker(size_t comm_size); - kvs_resize_action_t call_resize_fn(size_t comm_size); - size_t rank; - size_t size; + kvs_resize_action_t default_checker(int comm_size); + kvs_resize_action_t call_resize_fn(int comm_size); + + int rank; + int size; pmir_resize_fn_t resize_function = nullptr; std::shared_ptr h; diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h index 5bfde612c..564a309f7 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h @@ -13,16 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifndef DEF_INCLUDED -#define DEF_INCLUDED - -#include +#pragma once //TODO: change exit to something more useful #define SET_STR(dst, size, ...) \ do { \ if (snprintf(dst, size, __VA_ARGS__) > size) { \ - printf("Line so big (must be low %d)\n", size); \ + printf("line too long (must be shorter %d)\n", size); \ printf(__VA_ARGS__); \ exit(1); \ } \ @@ -37,41 +34,65 @@ } \ } while (0) -#define DO_RW_OP(op, fd, buf, size) \ +#define DO_RW_OP(op, fd, buf, size, memory_mutex, msg) \ do { \ - ssize_t res = 0; \ - size_t shift = 0; \ - while (shift != size) { \ - res = op(fd, (char*)buf + shift, size - shift); \ - if (res == -1) { \ - if (errno != EINTR) { \ - printf("read/write error: %s\n", strerror(errno)); \ + { \ + if (!fd) { \ + printf("" #msg ": " #op ": fd is closed, size %zu\n", size); \ + break; \ + } \ + std::lock_guard lock(memory_mutex); \ + ssize_t res = 0; \ + size_t shift = 0; \ + while (shift != size) { \ + res = op(fd, (char*)buf + shift, size - shift); \ + if (res == -1) { \ + if (errno != EINTR) { \ + printf("" #msg ": " #op ": error: buf %p, size %zu, shift %zu\n", \ + buf, \ + size, \ + shift); \ + perror("read/write error"); \ + exit(EXIT_FAILURE); \ + } \ + } \ + else if (res == 0) { \ + printf("" #msg ": " #op ": can not process all data, size %zu, shift %zu\n", \ + size, \ + shift); \ exit(EXIT_FAILURE); \ } \ - } \ - else { \ - shift += res; \ + else { \ + shift += res; \ + } \ } \ } \ } while (0) - -#define DO_RW_OP_1(op, fd, buf, size, res) \ +#define DO_RW_OP_1(op, fd, buf, size, res, msg) \ do { \ + if (!fd) { \ + printf("" #msg ": " #op ": fd is closed, size %zu\n", size); \ + break; \ + } \ size_t shift = 0; \ res = 0; \ do { \ res = op(fd, (char*)buf + shift, size - shift); \ if (res == -1) { \ if (errno != EINTR) { \ - printf("read/write error: %s\n", strerror(errno)); \ + printf("" #msg ": " #op ": error: buf %p, size %zu, shift %zu\n", \ + buf, \ + size, \ + shift); \ + perror("read/write error"); \ exit(EXIT_FAILURE); \ } \ - }\ 
+ } \ else { \ shift += res; \ } \ - } while (shift != size && res != 0); \ + } while ((shift != size) && (res != 0)); \ } while (0) #define BARRIER_NUM_MAX 1024 @@ -92,10 +113,12 @@ #define GREP_TEMPLATE "| grep \"%s\"" #define GREP_COUNT_TEMPLATE "| grep -c \"%s\"" #define CONCAT_TWO_COMMAND_TEMPLATE "%s %s" +#define RANK_TEMPLATE "%d" #define SIZE_T_TEMPLATE "%zu" -#define KVS_NAME "CCL_POD_ADDR" -#define KVS_BARRIER "CCL_BARRIER" +#define KVS_NAME "CCL_POD_ADDR" +#define KVS_BARRIER "CCL_BARRIER" +#define KVS_BARRIER_FULL "CCL_BARRIER_FULL" #define KVS_IDX "IDX" #define KVS_UP "CCL_UP" @@ -109,7 +132,7 @@ #define CCL_IP_LEN 128 -#define CHECKER_IP "hostname -I" +#define GET_IP_CMD "hostname -I" #define READ_ONLY "r" #define NULL_CHAR '\0' #define MAX_UP_IDX 2048 @@ -117,8 +140,4 @@ #define INITIAL_RANK_NUM "0" #define MAX_CLEAN_CHECKS 3 -#define STR_COPY(dst, src, len) { memcpy((dst), (src), (len-1)); dst[len - 1] = '\0'; } - extern char my_hostname[MAX_KVS_VAL_LENGTH]; - -#endif diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.cpp index e7ad2365b..c2a0c67fd 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.cpp @@ -13,46 +13,55 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "helper.h" +#include + +#include "util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp" #include "util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h" -size_t my_rank, count_pods; +int my_rank, count_pods; size_t barrier_num = 0; size_t up_idx; size_t applied = 0; rank_list_t* killed_ranks = NULL; -size_t killed_ranks_count = 0; +int killed_ranks_count = 0; rank_list_t* new_ranks = NULL; -size_t new_ranks_count = 0; +int new_ranks_count = 0; + +void kvs_str_copy(char* dst, const char* src, size_t bytes) { + strncpy(dst, src, bytes - 1); + dst[bytes - 1] = '\0'; +} + +size_t helper::replace_str(char* str, int old_rank, int new_rank) { + throw std::runtime_error("unexpected path"); -size_t helper::replace_str(char* str, size_t old_rank, size_t new_rank) { char old_str[INT_STR_SIZE]; char new_str[INT_STR_SIZE]; char* point_to_replace; - size_t old_str_size; - size_t new_str_size; - - SET_STR(old_str, INT_STR_SIZE, SIZE_T_TEMPLATE, old_rank); + int old_str_size; + int new_str_size; - SET_STR(new_str, INT_STR_SIZE, SIZE_T_TEMPLATE, new_rank); + SET_STR(old_str, INT_STR_SIZE, RANK_TEMPLATE, old_rank); + SET_STR(new_str, INT_STR_SIZE, RANK_TEMPLATE, new_rank); point_to_replace = strstr(str, old_str); if (point_to_replace == NULL) return 1; + old_str_size = strlen(old_str); new_str_size = strlen(new_str); if (old_str_size != new_str_size) { - size_t rest_len = strlen(point_to_replace); + size_t rest_len = strlen(point_to_replace) - old_str_size; memmove(point_to_replace + new_str_size, point_to_replace + old_str_size, rest_len); } - STR_COPY(point_to_replace, new_str, new_str_size); + memcpy(point_to_replace, new_str, new_str_size); return 0; } -void helper::update_ranks(size_t* old_count, rank_list_t** origin_list, const char* kvs_name) { +void helper::update_ranks(int* old_count, rank_list_t** origin_list, const char* kvs_name) { char** rank_nums = NULL; size_t rank_count = get_keys_values_by_name(kvs_name, NULL, &rank_nums); size_t i; @@ -79,7 +88,7 @@ 
void helper::update_ranks(size_t* old_count, rank_list_t** origin_list, const ch *old_count += cur_count; } -void helper::keep_first_n_up(size_t prev_new_ranks_count, size_t prev_killed_ranks_count) { +void helper::keep_first_n_up(int prev_new_ranks_count, int prev_killed_ranks_count) { rank_list_keep_first_n(&killed_ranks, prev_killed_ranks_count); rank_list_keep_first_n(&new_ranks, prev_new_ranks_count); } @@ -90,8 +99,8 @@ void helper::get_update_ranks(void) { } void helper::get_shift(shift_list_t** list) { - size_t shift_pods_count = 0; - size_t max_rank_survivor_pod = count_pods; + int shift_pods_count = 0; + int max_rank_survivor_pod = count_pods; rank_list_t* cur_new = new_ranks; rank_list_t* cur_killed = killed_ranks; @@ -184,8 +193,8 @@ void helper::accept_new_ranks(shift_list_t* cur_list) { while (cur_list != NULL) { if (cur_list->shift.type == CH_T_UPDATE || cur_list->shift.type == CH_T_NEW) { - SET_STR(old_rank_str, INT_STR_SIZE, SIZE_T_TEMPLATE, cur_list->shift.old_rank); - SET_STR(new_rank_str, INT_STR_SIZE, SIZE_T_TEMPLATE, cur_list->shift.new_rank); + SET_STR(old_rank_str, INT_STR_SIZE, RANK_TEMPLATE, cur_list->shift.old_rank); + SET_STR(new_rank_str, INT_STR_SIZE, RANK_TEMPLATE, cur_list->shift.new_rank); count_values = get_keys_values_by_name(KVS_APPROVED_NEW_POD, &kvs_keys, &kvs_values); @@ -215,7 +224,7 @@ void helper::accept_new_ranks(shift_list_t* cur_list) { free(kvs_values); } -void helper::update_kvs_info(size_t new_rank) { +void helper::update_kvs_info(int new_rank) { char kvs_name[MAX_KVS_NAME_LENGTH]; char kvs_key[MAX_KVS_KEY_LENGTH]; char kvs_val[MAX_KVS_VAL_LENGTH]; @@ -235,13 +244,13 @@ void helper::update_kvs_info(size_t new_rank) { } } -void helper::move_to_new_rank(size_t new_rank) { +void helper::move_to_new_rank(int new_rank) { char rank_str[INT_STR_SIZE]; update_kvs_info(new_rank); my_rank = new_rank; - SET_STR(rank_str, INT_STR_SIZE, SIZE_T_TEMPLATE, my_rank); + SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); // 
request_set_val(KVS_POD_REQUEST, my_hostname, rank_str); @@ -253,10 +262,10 @@ void helper::update_my_info(shift_list_t* list) { while (list != NULL) { if (list->shift.old_rank == my_rank) { - size_t old_rank = my_rank; + int old_rank = my_rank; move_to_new_rank(list->shift.new_rank); - SET_STR(rank_str, INT_STR_SIZE, SIZE_T_TEMPLATE, old_rank); + SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, old_rank); remove_name_key(KVS_POD_NUM, rank_str); @@ -296,7 +305,7 @@ void helper::post_my_info(void) { applied = 1; - SET_STR(my_rank_str, INT_STR_SIZE, SIZE_T_TEMPLATE, my_rank); + SET_STR(my_rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); set_value(KVS_POD_NUM, my_rank_str, my_hostname); @@ -315,7 +324,7 @@ void helper::post_my_info(void) { barrier_num = 0; } -size_t helper::update(shift_list_t** list, rank_list_t** dead_up_idx, size_t root_rank) { +size_t helper::update(shift_list_t** list, rank_list_t** dead_up_idx, int root_rank) { if (applied == 1) { if ((*list) != NULL) { if (my_rank == root_rank) { @@ -409,7 +418,7 @@ void helper::reg_rank(void) { my_rank = 0; set_value(KVS_POD_REQUEST, my_hostname, INITIAL_RANK_NUM); - SET_STR(rank_str, INT_STR_SIZE, SIZE_T_TEMPLATE, my_rank); + SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); while (1) { wait_shift = 0; @@ -444,7 +453,7 @@ void helper::reg_rank(void) { if (!wait_shift) { my_rank++; - SET_STR(rank_str, INT_STR_SIZE, SIZE_T_TEMPLATE, my_rank); + SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); set_value(KVS_POD_REQUEST, my_hostname, rank_str); } } @@ -487,13 +496,13 @@ void helper::up_kvs_new_and_dead(void) { up_kvs(KVS_APPROVED_DEAD_POD, KVS_DEAD_POD); } -void helper::get_new_root(size_t* old_root) { +void helper::get_new_root(int* old_root) { size_t i; char** rank_nums = NULL; size_t rank_count = get_keys_values_by_name(KVS_DEAD_POD, NULL, &rank_nums); for (i = 0; i < rank_count; i++) { - if (*old_root == (size_t)strtol(rank_nums[i], NULL, 10)) + if (*old_root == (int)strtol(rank_nums[i], NULL, 
10)) (*old_root)++; free(rank_nums[i]); } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp similarity index 81% rename from src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.h rename to src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp index db61f4c8b..1dd5e4be9 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp @@ -25,22 +25,25 @@ #include #include #include + #include "def.h" -#include "rank_list.h" -#include "shift_list.h" -#include "kvs_keeper.h" +#include "rank_list.hpp" +#include "shift_list.hpp" +#include "kvs_keeper.hpp" #include "kvs/ikvs_wrapper.h" -extern size_t my_rank, count_pods; +extern int my_rank, count_pods; extern size_t barrier_num; extern size_t up_idx; extern size_t applied; extern rank_list_t* killed_ranks; -extern size_t killed_ranks_count; +extern int killed_ranks_count; extern rank_list_t* new_ranks; -extern size_t new_ranks_count; +extern int new_ranks_count; + +void kvs_str_copy(char* dst, const char* src, size_t bytes); class helper { public: @@ -54,7 +57,7 @@ class helper { void wait_accept(void); - size_t update(shift_list_t** list, rank_list_t** dead_up_idx, size_t root_rank); + size_t update(shift_list_t** list, rank_list_t** dead_up_idx, int root_rank); void up_pods_count(void); @@ -66,9 +69,9 @@ class helper { void up_kvs_new_and_dead(void); - void keep_first_n_up(size_t prev_new_ranks_count, size_t prev_killed_ranks_count); + void keep_first_n_up(int prev_new_ranks_count, int prev_killed_ranks_count); - void get_new_root(size_t* old_root); + void get_new_root(int* old_root); /*Work with KVS, new*/ size_t set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val); @@ -90,17 +93,17 @@ class helper { /*Work with KVS, new*/ private: - size_t replace_str(char* str, size_t old_rank, size_t new_rank); + size_t replace_str(char* str, int old_rank, int 
new_rank); - void update_ranks(size_t* old_count, rank_list_t** origin_list, const char* kvs_name); + void update_ranks(int* old_count, rank_list_t** origin_list, const char* kvs_name); void clean_dead_pods_info(rank_list_t* dead_up_idx); void accept_new_ranks(shift_list_t* cur_list); - void update_kvs_info(size_t new_rank); + void update_kvs_info(int new_rank); - void move_to_new_rank(size_t new_rank); + void move_to_new_rank(int new_rank); void update_my_info(shift_list_t* list); diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h index 26770913c..6f68a78a6 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once + #include class ikvs_wrapper { diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp index cceca2e6d..e0cb5e518 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -24,10 +25,11 @@ #include #include +#include "util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp" #include "util/pm/pmi_resizable_rt/pmi_resizable/def.h" #include "internal_kvs.h" -#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.h" -#include "util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.h" +#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.hpp" +#include "util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.hpp" #define CCL_KVS_IP_PORT_ENV "CCL_KVS_IP_PORT" #define CCL_KVS_IP_EXCHANGE_ENV "CCL_KVS_IP_EXCHANGE" @@ -37,13 +39,27 @@ #define MAX_CLIENT_COUNT 300 #define CONNECTION_TIMEOUT 120 -static pthread_t thread = 0; 
+static pthread_t kvs_thread = 0; + static char main_host_ip[CCL_IP_LEN]; char local_host_ip[CCL_IP_LEN]; -static int sock_listener = 0; + static size_t main_port; static size_t local_port; static size_t is_master = 0; +static std::mutex client_memory_mutex; +static std::mutex server_memory_mutex; + +static struct sockaddr_in main_server_address; +static struct sockaddr_in local_server_address; + +static int + client_op_sock; /* used on client side to send commands and to recv result to/from server */ +static int + server_listen_sock; /* used on server side to handle new incoming connect requests from clients */ + +static int client_control_sock; /* used on client side to control local kvs server */ +static int server_control_sock; /* used on server side to be controlled by local client */ typedef enum ip_getting_type { IGT_K8S = 0, @@ -54,7 +70,7 @@ static ip_getting_type_t ip_getting_mode = IGT_K8S; typedef enum kvs_access_mode { AM_CONNECT = -1, -// AM_DISCONNECT = 1, + // AM_DISCONNECT = 1, AM_PUT = 2, AM_REMOVE = 3, AM_GET_COUNT = 4, @@ -71,19 +87,20 @@ typedef struct kvs_request { char val[MAX_KVS_VAL_LENGTH]; } kvs_request_t; -static struct sockaddr_in main_server_address; -static struct sockaddr_in local_server_address; -static int sock_sender, local_sock, accepted_local_sock; - size_t internal_kvs::kvs_set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val) { kvs_request_t request; memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_PUT; - STR_COPY(request.name, kvs_name, MAX_KVS_NAME_LENGTH); - STR_COPY(request.key, kvs_key, MAX_KVS_KEY_LENGTH); - STR_COPY(request.val, kvs_val, MAX_KVS_VAL_LENGTH); + kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); + kvs_str_copy(request.key, kvs_key, MAX_KVS_KEY_LENGTH); + kvs_str_copy(request.val, kvs_val, MAX_KVS_VAL_LENGTH); - DO_RW_OP(write, sock_sender, &request, sizeof(kvs_request_t)); + DO_RW_OP(write, + client_op_sock, + &request, + sizeof(kvs_request_t), + 
client_memory_mutex, + "client: put_key_value"); return 0; } @@ -92,10 +109,15 @@ size_t internal_kvs::kvs_remove_name_key(const char* kvs_name, const char* kvs_k kvs_request_t request; memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_REMOVE; - STR_COPY(request.name, kvs_name, MAX_KVS_NAME_LENGTH); - STR_COPY(request.key, kvs_key, MAX_KVS_KEY_LENGTH); + kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); + kvs_str_copy(request.key, kvs_key, MAX_KVS_KEY_LENGTH); - DO_RW_OP(write, sock_sender, &request, sizeof(kvs_request_t)); + DO_RW_OP(write, + client_op_sock, + &request, + sizeof(kvs_request_t), + client_memory_mutex, + "client: remove_key"); return 0; } @@ -107,16 +129,28 @@ size_t internal_kvs::kvs_get_value_by_name_key(const char* kvs_name, memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_GET_VAL; size_t is_exist = 0; - STR_COPY(request.name, kvs_name, MAX_KVS_NAME_LENGTH); - STR_COPY(request.key, kvs_key, MAX_KVS_KEY_LENGTH); + kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); + kvs_str_copy(request.key, kvs_key, MAX_KVS_KEY_LENGTH); memset(kvs_val, 0, MAX_KVS_VAL_LENGTH); - DO_RW_OP(write, sock_sender, &request, sizeof(kvs_request_t)); + DO_RW_OP( + write, client_op_sock, &request, sizeof(request), client_memory_mutex, "client: get_value"); + + DO_RW_OP(read, + client_op_sock, + &is_exist, + sizeof(is_exist), + client_memory_mutex, + "client: get_value is_exist"); - DO_RW_OP(read, sock_sender, &is_exist, sizeof(size_t)); if (is_exist) { - DO_RW_OP(read, sock_sender, &request, sizeof(kvs_request_t)); - STR_COPY(kvs_val, request.val, MAX_KVS_VAL_LENGTH); + DO_RW_OP(read, + client_op_sock, + &request, + sizeof(request), + client_memory_mutex, + "client: get_value read data"); + kvs_str_copy(kvs_val, request.val, MAX_KVS_VAL_LENGTH); } return strlen(kvs_val); @@ -127,11 +161,21 @@ size_t internal_kvs::kvs_get_count_names(const char* kvs_name) { kvs_request_t request; memset(&request, 0, sizeof(kvs_request_t)); 
request.mode = AM_GET_COUNT; - STR_COPY(request.name, kvs_name, MAX_KVS_NAME_LENGTH); - - DO_RW_OP(write, sock_sender, &request, sizeof(kvs_request_t)); - - DO_RW_OP(read, sock_sender, &count_names, sizeof(size_t)); + kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); + + DO_RW_OP(write, + client_op_sock, + &request, + sizeof(kvs_request_t), + client_memory_mutex, + "client: get_count"); + + DO_RW_OP(read, + client_op_sock, + &count_names, + sizeof(size_t), + client_memory_mutex, + "client: get_count read data"); return count_names; } @@ -146,26 +190,44 @@ size_t internal_kvs::kvs_get_keys_values_by_name(const char* kvs_name, memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_GET_KEYS_VALUES; - STR_COPY(request.name, kvs_name, MAX_KVS_NAME_LENGTH); - - DO_RW_OP(write, sock_sender, &request, sizeof(kvs_request_t)); - - DO_RW_OP(read, sock_sender, &count, sizeof(size_t)); + kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); + + DO_RW_OP(write, + client_op_sock, + &request, + sizeof(kvs_request_t), + client_memory_mutex, + "client: get_keys_values"); + + DO_RW_OP(read, + client_op_sock, + &count, + sizeof(size_t), + client_memory_mutex, + "client: get_keys_values read size"); if (count == 0) return count; answers = (kvs_request_t*)calloc(count, sizeof(kvs_request_t)); - - DO_RW_OP(read, sock_sender, answers, sizeof(kvs_request_t) * count); + DO_RW_OP(read, + client_op_sock, + answers, + sizeof(kvs_request_t) * count, + client_memory_mutex, + "client: get_keys_values read data"); if (kvs_keys != NULL) { if (*kvs_keys != NULL) free(*kvs_keys); *kvs_keys = (char**)calloc(count, sizeof(char*)); + if ((*kvs_keys) == NULL) { + printf("Memory allocation failed\n"); + exit(1); + } for (i = 0; i < count; i++) { (*kvs_keys)[i] = (char*)calloc(MAX_KVS_KEY_LENGTH, sizeof(char)); - STR_COPY((*kvs_keys)[i], answers[i].key, MAX_KVS_KEY_LENGTH); + kvs_str_copy((*kvs_keys)[i], answers[i].key, MAX_KVS_KEY_LENGTH); } } if (kvs_values != NULL) { @@ -173,9 
+235,13 @@ size_t internal_kvs::kvs_get_keys_values_by_name(const char* kvs_name, free(*kvs_values); *kvs_values = (char**)calloc(count, sizeof(char*)); + if ((*kvs_values) == NULL) { + printf("Memory allocation failed\n"); + exit(1); + } for (i = 0; i < count; i++) { (*kvs_values)[i] = (char*)calloc(MAX_KVS_VAL_LENGTH, sizeof(char)); - STR_COPY((*kvs_values)[i], answers[i].val, MAX_KVS_VAL_LENGTH); + kvs_str_copy((*kvs_values)[i], answers[i].val, MAX_KVS_VAL_LENGTH); } } @@ -194,40 +260,51 @@ size_t internal_kvs::kvs_get_replica_size(void) { memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_GET_REPLICA; - DO_RW_OP(write, sock_sender, &request, sizeof(kvs_request_t)); + DO_RW_OP(write, + client_op_sock, + &request, + sizeof(kvs_request_t), + client_memory_mutex, + "client: get_replica"); - DO_RW_OP(read, sock_sender, &replica_size, sizeof(size_t)); + DO_RW_OP(read, + client_op_sock, + &replica_size, + sizeof(size_t), + client_memory_mutex, + "client: get_replica read size"); } return replica_size; } void* kvs_server_init(void* args) { struct sockaddr_in addr; - int local_sock; + int server_control_sock; kvs_request_t request; size_t count; - size_t clients_count = 0; - int is_stop = 0; + size_t client_count = 0; + int should_stop = 0; fd_set read_fds; - int i, client_socket[MAX_CLIENT_COUNT], max_sd, sd; + int i, client_op_sockets[MAX_CLIENT_COUNT], max_sd, sd; int so_reuse = 1; int ret = 0; + #ifdef SO_REUSEPORT - setsockopt(sock_listener, SOL_SOCKET, SO_REUSEPORT, &so_reuse, sizeof(so_reuse)); + setsockopt(server_listen_sock, SOL_SOCKET, SO_REUSEPORT, &so_reuse, sizeof(so_reuse)); #else - setsockopt(sock_listener, SOL_SOCKET, SO_REUSEADDR, &so_reuse, sizeof(so_reuse)); + setsockopt(server_listen_sock, SOL_SOCKET, SO_REUSEADDR, &so_reuse, sizeof(so_reuse)); #endif for (i = 0; i < MAX_CLIENT_COUNT; i++) { - client_socket[i] = 0; + client_op_sockets[i] = 0; } - if ((local_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - printf("Server: socket init failed 
- %s\n", strerror(errno)); - exit(1); + if ((server_control_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + perror("server: server_control_sock init"); + exit(EXIT_FAILURE); } - while (connect(local_sock, (struct sockaddr*)args, sizeof(addr)) < 0) { + while (connect(server_control_sock, (struct sockaddr*)args, sizeof(addr)) < 0) { } memset(&addr, 0, sizeof(addr)); @@ -236,18 +313,19 @@ void* kvs_server_init(void* args) { addr.sin_addr.s_addr = INADDR_ANY; addr.sin_port = 0; - if (listen(sock_listener, MAX_CLIENT_COUNT) < 0) { - perror("listen"); + if (listen(server_listen_sock, MAX_CLIENT_COUNT) < 0) { + perror("server: server_listen_sock listen"); exit(EXIT_FAILURE); } - while (!is_stop || clients_count > 1) { + while (!should_stop || client_count > 1) { FD_ZERO(&read_fds); - FD_SET(sock_listener, &read_fds); - FD_SET(local_sock, &read_fds); - max_sd = sock_listener; + FD_SET(server_listen_sock, &read_fds); + FD_SET(server_control_sock, &read_fds); + max_sd = server_listen_sock; + for (i = 0; i < MAX_CLIENT_COUNT; i++) { - sd = client_socket[i]; + sd = client_op_sockets[i]; if (sd > 0) FD_SET(sd, &read_fds); @@ -255,42 +333,61 @@ void* kvs_server_init(void* args) { if (sd > max_sd) max_sd = sd; } - if (local_sock > max_sd) - max_sd = local_sock; - if ((select(max_sd + 1, &read_fds, NULL, NULL, NULL) < 0) && (errno != EINTR)) { - perror("select"); - exit(EXIT_FAILURE); + + if (server_control_sock > max_sd) + max_sd = server_control_sock; + + if (select(max_sd + 1, &read_fds, NULL, NULL, NULL) < 0) { + if (errno != EINTR) { + perror("server: select"); + exit(EXIT_FAILURE); + } + else { + /* restart select */ + continue; + } } - if (FD_ISSET(local_sock, &read_fds)) { - DO_RW_OP_1(read, local_sock, &request,sizeof(kvs_request_t), ret); + if (FD_ISSET(server_control_sock, &read_fds)) { + DO_RW_OP_1(read, + server_control_sock, + &request, + sizeof(kvs_request_t), + ret, + "server: get control msg from client"); if (ret == 0) { - close(local_sock); - local_sock = 0; + 
close(server_control_sock); + server_control_sock = 0; } if (request.mode != AM_FINALIZE) { - printf("server: Wrong access mode for local socket.\n"); - exit(1); + printf("server: invalid access mode for local socket\n"); + exit(EXIT_FAILURE); } - is_stop = 1; + should_stop = 1; } + for (i = 0; i < MAX_CLIENT_COUNT; i++) { - sd = client_socket[i]; + sd = client_op_sockets[i]; if (sd == 0) continue; if (FD_ISSET(sd, &read_fds)) { - DO_RW_OP_1(read, sd, &request,sizeof(kvs_request_t), ret); + DO_RW_OP_1(read, + sd, + &request, + sizeof(kvs_request_t), + ret, + "server: get command from client"); if (ret == 0) { close(sd); - client_socket[i] = 0; - clients_count--; + client_op_sockets[i] = 0; + client_count--; continue; } switch (request.mode) { case AM_CONNECT: { - clients_count++; + client_count++; break; } case AM_PUT: { @@ -303,21 +400,41 @@ void* kvs_server_init(void* args) { } case AM_GET_VAL: { count = get_val(request.name, request.key, request.val, ST_SERVER); - DO_RW_OP(write, client_socket[i], &count, sizeof(size_t)); + DO_RW_OP(write, + client_op_sockets[i], + &count, + sizeof(size_t), + server_memory_mutex, + "server: get_value write size"); if (count != 0) - DO_RW_OP(write, client_socket[i], &request, sizeof(kvs_request_t)); + DO_RW_OP(write, + client_op_sockets[i], + &request, + sizeof(kvs_request_t), + server_memory_mutex, + "server: get_value write data"); break; } case AM_GET_COUNT: { count = get_count(request.name, ST_SERVER); - DO_RW_OP(write, client_socket[i], &count, sizeof(size_t)); + DO_RW_OP(write, + client_op_sockets[i], + &count, + sizeof(size_t), + server_memory_mutex, + "server: get_count"); break; } case AM_GET_REPLICA: { char* replica_size_str = getenv(CCL_WORLD_SIZE_ENV); count = (replica_size_str != NULL) ? 
strtol(replica_size_str, NULL, 10) - : clients_count; - DO_RW_OP(write, client_socket[i], &count, sizeof(size_t)); + : client_count; + DO_RW_OP(write, + client_op_sockets[i], + &count, + sizeof(size_t), + server_memory_mutex, + "server: get_replica"); break; } case AM_GET_KEYS_VALUES: { @@ -326,22 +443,34 @@ void* kvs_server_init(void* args) { size_t j; kvs_request_t* answers = NULL; - count = - get_keys_values(request.name, &kvs_keys, &kvs_values, ST_SERVER); + count = get_keys_values(request.name, &kvs_keys, &kvs_values, ST_SERVER); - DO_RW_OP(write, client_socket[i], &count, sizeof(size_t)); + DO_RW_OP(write, + client_op_sockets[i], + &count, + sizeof(size_t), + server_memory_mutex, + "server: get_keys_values write size"); if (count == 0) break; answers = (kvs_request_t*)calloc(count, sizeof(kvs_request_t)); + if (answers == NULL) { + printf("Memory allocation failed\n"); + break; + } for (j = 0; j < count; j++) { - STR_COPY(answers[j].name, request.name, MAX_KVS_NAME_LENGTH); - STR_COPY(answers[j].key, kvs_keys[j], MAX_KVS_KEY_LENGTH); - STR_COPY(answers[j].val, kvs_values[j], MAX_KVS_VAL_LENGTH); + kvs_str_copy(answers[j].name, request.name, MAX_KVS_NAME_LENGTH); + kvs_str_copy(answers[j].key, kvs_keys[j], MAX_KVS_KEY_LENGTH); + kvs_str_copy(answers[j].val, kvs_values[j], MAX_KVS_VAL_LENGTH); } - DO_RW_OP( - write, client_socket[i], answers, sizeof(kvs_request_t) * count); + DO_RW_OP(write, + client_op_sockets[i], + answers, + sizeof(kvs_request_t) * count, + server_memory_mutex, + "server: get_keys_values write data"); free(answers); for (j = 0; j < count; j++) { @@ -355,41 +484,59 @@ void* kvs_server_init(void* args) { default: { if (request.name[0] == '\0') continue; - printf("server: Unknown request mode - %d.\n", request.mode); - exit(1); + printf("server: unknown request mode - %d.\n", request.mode); + exit(EXIT_FAILURE); } } } } - if (FD_ISSET(sock_listener, &read_fds)) { + + if (FD_ISSET(server_listen_sock, &read_fds)) { int new_socket; - socklen_t 
peer_addr_size = sizeof(addr); - if ((new_socket = accept(sock_listener, (struct sockaddr*)&addr, (socklen_t*)&peer_addr_size)) < + socklen_t peer_addr_size = sizeof(addr); + if ((new_socket = accept( + server_listen_sock, (struct sockaddr*)&addr, (socklen_t*)&peer_addr_size)) < 0) { - perror("accept"); + perror("server: server_listen_sock accept"); exit(EXIT_FAILURE); } for (i = 0; i < MAX_CLIENT_COUNT; i++) { - if (client_socket[i] == 0) { - client_socket[i] = new_socket; + if (client_op_sockets[i] == 0) { + client_op_sockets[i] = new_socket; break; } } if (i >= MAX_CLIENT_COUNT) { - printf("server: Not enough free sockets\n"); - exit(1); + printf("server: no free sockets\n"); + exit(EXIT_FAILURE); } } } kvs_keeper_clear(ST_SERVER); - DO_RW_OP(write, local_sock, &is_stop, sizeof(int)); - close(local_sock); + + if (server_control_sock) { + DO_RW_OP_1(write, + server_control_sock, + &should_stop, + sizeof(should_stop), + ret, + "server: send control msg to client"); + } + + close(server_control_sock); + server_control_sock = 0; + for (i = 0; i < MAX_CLIENT_COUNT; i++) { - if (client_socket[i] != 0) - close(client_socket[i]); + if (client_op_sockets[i] != 0) { + close(client_op_sockets[i]); + client_op_sockets[i] = 0; + } } - close(sock_listener); + + close(server_listen_sock); + server_listen_sock = 0; + return NULL; } @@ -404,7 +551,7 @@ size_t init_main_server_by_k8s(void) { main_port = strtol(port_str, NULL, 10); main_server_address.sin_port = main_port; if (inet_pton(AF_INET, main_host_ip, &(main_server_address.sin_addr)) <= 0) { - printf("\nInvalid address/ Address not supported: %s\n", main_host_ip); + printf("invalid address/ address not supported: %s\n", main_host_ip); return 1; } return 0; @@ -417,14 +564,14 @@ size_t init_main_server_by_env(void) { tmp_host_ip = getenv(CCL_KVS_IP_PORT_ENV); if (tmp_host_ip == NULL) { - printf("You must set %s\n", CCL_KVS_IP_PORT_ENV); + printf("specify %s\n", CCL_KVS_IP_PORT_ENV); return 1; } memset(main_host_ip, 0, 
CCL_IP_LEN); - STR_COPY(main_host_ip, tmp_host_ip, CCL_IP_LEN); + kvs_str_copy(main_host_ip, tmp_host_ip, CCL_IP_LEN); if ((port = strstr(main_host_ip, "_")) == NULL) { - printf("You must set %s like IP_PORT\n", CCL_KVS_IP_PORT_ENV); + printf("set %s in format <ip>_<port>\n", CCL_KVS_IP_PORT_ENV); return 1; } port[0] = '\0'; @@ -434,7 +581,7 @@ size_t init_main_server_by_env(void) { main_server_address.sin_port = main_port; if (inet_pton(AF_INET, main_host_ip, &(main_server_address.sin_addr)) <= 0) { - printf("\nInvalid address/ Address not supported: %s\n", main_host_ip); + printf("invalid address / address not supported: %s\n", main_host_ip); return 1; } return 0; @@ -448,22 +595,22 @@ size_t init_main_server_by_string(const char* main_addr) { main_server_address.sin_family = AF_INET; - if ((sock_listener = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - printf("Server: socket init failed - %s\n", strerror(errno)); - exit(1); + if ((server_listen_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + perror("init_main_server_by_string: server_listen_sock init"); + exit(EXIT_FAILURE); } - while (bind(sock_listener, + while (bind(server_listen_sock, (const struct sockaddr*)&local_server_address, sizeof(local_server_address)) < 0) { local_server_address.sin_port++; } memset(main_host_ip, 0, CCL_IP_LEN); - STR_COPY(main_host_ip, main_addr, CCL_IP_LEN); + kvs_str_copy(main_host_ip, main_addr, CCL_IP_LEN); if ((port = strstr(main_host_ip, "_")) == NULL) { - printf("You must set %s like IP_PORT\n", CCL_KVS_IP_PORT_ENV); + printf("init_main_server_by_string: set %s in format <ip>_<port>\n", CCL_KVS_IP_PORT_ENV); return 1; } port[0] = '\0'; @@ -473,7 +620,9 @@ size_t init_main_server_by_string(const char* main_addr) { main_server_address.sin_port = main_port; if (inet_pton(AF_INET, main_host_ip, &(main_server_address.sin_addr)) <= 0) { - printf("\nInvalid address/ Address not supported: %s(%s)\n", main_host_ip, strerror(errno)); + printf("init_main_server_by_string: invalid address / address not 
supported: %s\n", + main_host_ip); + perror("init_main_server_by_string: inet_pton"); return 1; } return 0; @@ -482,12 +631,13 @@ size_t init_main_server_by_string(const char* main_addr) { size_t internal_kvs::kvs_main_server_address_reserve(char* main_address) { FILE* fp; char* additional_local_host_ips; - if ((fp = popen(CHECKER_IP, READ_ONLY)) == NULL) { - printf("Can't get host IP - %s\n", strerror(errno)); - exit(1); + if ((fp = popen(GET_IP_CMD, READ_ONLY)) == NULL) { + perror("reserve_main_address: can not get host IP"); + exit(EXIT_FAILURE); } CHECK_FGETS(fgets(local_host_ip, CCL_IP_LEN, fp), local_host_ip); pclose(fp); + while (local_host_ip[strlen(local_host_ip) - 1] == '\n' || local_host_ip[strlen(local_host_ip) - 1] == ' ') local_host_ip[strlen(local_host_ip) - 1] = NULL_CHAR; @@ -495,22 +645,24 @@ size_t internal_kvs::kvs_main_server_address_reserve(char* main_address) { additional_local_host_ips[0] = NULL_CHAR; if (strlen(local_host_ip) >= CCL_IP_LEN - INT_STR_SIZE - 1) { - printf("Error: Local host IP is too bigger: %zu, expected: %d\n", + printf("reserve_main_address: local host IP is too long: %zu, expected: %d\n", strlen(local_host_ip), CCL_IP_LEN - INT_STR_SIZE - 1); - exit(1); + exit(EXIT_FAILURE); } - if ((sock_listener = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - printf("Server: socket init failed - %s\n", strerror(errno)); - exit(1); + + if ((server_listen_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + perror("reserve_main_address: server_listen_sock init"); + exit(EXIT_FAILURE); } + main_server_address.sin_family = AF_INET; main_server_address.sin_addr.s_addr = inet_addr(local_host_ip); main_server_address.sin_port = 1; local_server_address.sin_family = AF_INET; local_server_address.sin_addr.s_addr = inet_addr(local_host_ip); - while (bind(sock_listener, + while (bind(server_listen_sock, (const struct sockaddr*)&main_server_address, sizeof(main_server_address)) < 0) { main_server_address.sin_port++; @@ -532,10 +684,11 @@ size_t 
init_main_server_address(const char* main_addr) { FILE* fp; char* additional_local_host_ips; - if ((fp = popen(CHECKER_IP, READ_ONLY)) == NULL) { - printf("Can't get host IP\n"); - exit(1); + if ((fp = popen(GET_IP_CMD, READ_ONLY)) == NULL) { + perror("init_main_server_address: can not get host IP"); + exit(EXIT_FAILURE); } + memset(local_host_ip, 0, CCL_IP_LEN); CHECK_FGETS(fgets(local_host_ip, CCL_IP_LEN, fp), local_host_ip); pclose(fp); @@ -557,14 +710,14 @@ size_t init_main_server_address(const char* main_addr) { ip_getting_mode = IGT_K8S; } else { - printf("Unknown %s: %s\n", CCL_KVS_IP_EXCHANGE_ENV, ip_getting_type); + printf("unknown %s: %s\n", CCL_KVS_IP_EXCHANGE_ENV, ip_getting_type); return 1; } } if (main_addr != NULL) { ip_getting_mode = IGT_ENV; - if (sock_listener == 0) + if (server_listen_sock == 0) init_main_server_by_string(main_addr); return 0; } @@ -575,14 +728,15 @@ size_t init_main_server_address(const char* main_addr) { main_server_address.sin_family = AF_INET; - if ((sock_listener = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - printf("Server: socket init failed - %s\n", strerror(errno)); - exit(1); + if ((server_listen_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + ; + perror("init_main_server_address: server_listen_sock init"); + exit(EXIT_FAILURE); } switch (ip_getting_mode) { case IGT_K8S: { - while (bind(sock_listener, + while (bind(server_listen_sock, (const struct sockaddr*)&local_server_address, sizeof(local_server_address)) < 0) { local_server_address.sin_port++; @@ -610,11 +764,11 @@ size_t init_main_server_address(const char* main_addr) { } } if (is_master_node) { - if (bind(sock_listener, + if (bind(server_listen_sock, (const struct sockaddr*)&main_server_address, sizeof(main_server_address)) < 0) { - printf("PORT %d busy\n", main_server_address.sin_port); - while (bind(sock_listener, + printf("port [%d] is busy\n", main_server_address.sin_port); + while (bind(server_listen_sock, (const struct sockaddr*)&local_server_address, 
sizeof(local_server_address)) < 0) { local_server_address.sin_port++; @@ -626,7 +780,7 @@ size_t init_main_server_address(const char* main_addr) { } } else { - while (bind(sock_listener, + while (bind(server_listen_sock, (const struct sockaddr*)&local_server_address, sizeof(local_server_address)) < 0) { local_server_address.sin_port++; @@ -637,7 +791,7 @@ size_t init_main_server_address(const char* main_addr) { return res; } default: { - printf("Unknown %s\n", CCL_KVS_IP_EXCHANGE_ENV); + printf("unknown %s\n", CCL_KVS_IP_EXCHANGE_ENV); return 1; } } @@ -657,55 +811,69 @@ size_t internal_kvs::kvs_init(const char* main_addr) { addr.sin_addr.s_addr = inet_addr("127.0.0.1"); addr.sin_port = 1; - if ((sock_sender = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - printf("\n Socket creation error \n"); + if ((client_op_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + perror("kvs_init: client_op_sock init"); return 1; } - if ((local_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - printf("\n Socket creation error: %s\n", strerror(errno)); + + if ((server_control_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + perror("kvs_init: server_control_sock init"); return 1; } + if (init_main_server_address(main_addr)) { - printf("Init main server address error\n"); - close(sock_sender); - close(local_sock); + printf("kvs_init: init main server address error\n"); + close(client_op_sock); + close(server_control_sock); + client_op_sock = 0; + server_control_sock = 0; return 1; } - while (bind(local_sock, (const struct sockaddr*)&addr, sizeof(addr)) < 0) { + + while (bind(server_control_sock, (const struct sockaddr*)&addr, sizeof(addr)) < 0) { addr.sin_port++; } - if (listen(local_sock, 1) < 0) { - printf("listener error: %s\n", strerror(errno)); + if (listen(server_control_sock, 1) < 0) { + perror("kvs_init: server_control_sock listen"); exit(EXIT_FAILURE); } - getsockname(local_sock, (struct sockaddr*)&addr, &len); - err = pthread_create(&thread, NULL, kvs_server_init, &addr); + + 
getsockname(server_control_sock, (struct sockaddr*)&addr, &len); + err = pthread_create(&kvs_thread, NULL, kvs_server_init, &addr); if (err) { - printf("error while creating listener thread, pthread_create returns %d\n", err); + printf("kvs_init: failed to create kvs server thread, pthread_create returns %d\n", err); return 1; } - if ((accepted_local_sock = accept(local_sock, NULL, NULL)) < 0) { - printf("Client: accept error: %s\n", strerror(errno)); + if ((client_control_sock = accept(server_control_sock, NULL, NULL)) < 0) { + perror("kvs_init: server_control_sock accept"); exit(EXIT_FAILURE); } + /* Wait connection to master */ start_time = time(NULL); do { err = connect( - sock_sender, (struct sockaddr*)&main_server_address, sizeof(main_server_address)); + client_op_sock, (struct sockaddr*)&main_server_address, sizeof(main_server_address)); connection_time = time(NULL) - start_time; } while ((err < 0) && (connection_time < CONNECTION_TIMEOUT)); if (connection_time >= CONNECTION_TIMEOUT) { - printf("Connection error: timeout limit (%ld > %d)\n", connection_time, CONNECTION_TIMEOUT); - exit(1); + printf("kvs_init: connection error: timeout limit (%ld > %d)\n", + connection_time, + CONNECTION_TIMEOUT); + exit(EXIT_FAILURE); } request.mode = AM_CONNECT; - DO_RW_OP(write, sock_sender, &request, sizeof(kvs_request_t)); + DO_RW_OP(write, + client_op_sock, + &request, + sizeof(kvs_request_t), + client_memory_mutex, + "client: connect"); if (strstr(main_host_ip, local_host_ip) && local_port == main_port) { is_master = 1; @@ -719,28 +887,46 @@ size_t internal_kvs::kvs_finalize(void) { kvs_request_t request; memset(&request, 0, sizeof(kvs_request_t)); - if (thread != 0) { + if (kvs_thread != 0) { void* exit_code; int err; request.mode = AM_FINALIZE; - DO_RW_OP(write, accepted_local_sock, &request, sizeof(kvs_request_t)); - - DO_RW_OP(read, accepted_local_sock, &err, sizeof(int)); - - err = pthread_join(thread, &exit_code); + DO_RW_OP(write, + client_control_sock, + 
&request, + sizeof(kvs_request_t), + client_memory_mutex, + "client: finalize start"); + + DO_RW_OP(read, + client_control_sock, + &err, + sizeof(int), + client_memory_mutex, + "client: finalize complete"); + + err = pthread_join(kvs_thread, &exit_code); if (err) { - printf("error while joining progress listener, pthread_join returns %d\n", err); + printf("kvs_finalize: failed to stop kvs server thread, pthread_join returns %d\n", + err); } - thread = 0; - close(accepted_local_sock); - close(local_sock); + + kvs_thread = 0; + + close(client_control_sock); + close(server_control_sock); + + client_control_sock = 0; + server_control_sock = 0; } - close(sock_sender); + close(client_op_sock); + client_op_sock = 0; if (ip_getting_mode == IGT_K8S) request_k8s_kvs_finalize(is_master); is_inited = false; + return 0; } internal_kvs::~internal_kvs() { diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h index 41ddf1f9e..8c8f5b0ff 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifndef KVS -#define KVS +#pragma once #include #include "ikvs_wrapper.h" + class internal_kvs final : public ikvs_wrapper { public: size_t kvs_set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val) override; @@ -47,4 +47,3 @@ class internal_kvs final : public ikvs_wrapper { private: bool is_inited{ false }; }; -#endif \ No newline at end of file diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/kvs_common_attr.hpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/kvs_common_attr.hpp new file mode 100644 index 000000000..8f5eee319 --- /dev/null +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/kvs_common_attr.hpp @@ -0,0 +1,53 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#pragma once + +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/kvs_attr_ids_traits.hpp" + +namespace ccl { + +class ccl_kvs_attr_impl { +public: + /** + * `version` operations + */ + using version_traits_t = detail::ccl_api_type_attr_traits; + + const typename version_traits_t::return_type& get_attribute_value( + const version_traits_t& id) const { + return version; + } + + typename version_traits_t::return_type set_attribute_value(typename version_traits_t::type val, + const version_traits_t& t) { + (void)t; + throw ccl::exception("Set value for 'ccl::kvs_attr_id::version' is not allowed"); + return version; + } + + ccl_kvs_attr_impl(const typename version_traits_t::return_type& version) : version(version) {} + + template + bool is_valid() const noexcept { + return (attr_id == kvs_attr_id::version); + } + +protected: + typename version_traits_t::return_type version; +}; + +} // namespace ccl diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.cpp index 0446fce0a..e648f1fbc 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.cpp @@ -21,28 +21,27 @@ users_kvs::users_kvs(std::shared_ptr kvs) : kvs(kvs) {} size_t users_kvs::kvs_set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val) { - std::string name(kvs_name), key(kvs_key); + ccl::string_class name(kvs_name), key(kvs_key); ccl::vector_class vec_val(kvs_val, kvs_val + strlen(kvs_val) + 1); vec_val[strlen(kvs_val)] = '\0'; - kvs->set((name + key).c_str(), vec_val); + kvs->set(name + key, vec_val); return 0; } size_t users_kvs::kvs_remove_name_key(const char* kvs_name, const char* kvs_key) { ccl::vector_class kvs_val = { '\0' }; - std::string name(kvs_name), key(kvs_key); - kvs->set((name + key).c_str(), kvs_val); - + ccl::string_class name(kvs_name), key(kvs_key); + kvs->set(name + key, kvs_val); return 0; } 
size_t users_kvs::kvs_get_value_by_name_key(const char* kvs_name, const char* kvs_key, char* kvs_val) { - std::string name(kvs_name), key(kvs_key); + ccl::string_class name(kvs_name), key(kvs_key); + ccl::vector_class res = kvs->get(name + key); - ccl::vector_class res = kvs->get((name + key).c_str()); if (res.data()) SET_STR(kvs_val, MAX_KVS_VAL_LENGTH, "%s", res.data()); else diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h index 4ad3abbac..1d220ebff 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once + #include #include #include "oneapi/ccl.hpp" diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.c b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.cpp similarity index 87% rename from src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.c rename to src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.cpp index 64701f1d0..4f554a752 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.c +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.cpp @@ -15,7 +15,9 @@ */ #include #include -#include "kvs_keeper.h" + +#include "util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp" +#include "kvs_keeper.hpp" #include "def.h" #define COMPARE_STR(str1, str2, str2_len) (strstr((str1), (str2)) && (strlen(str1) == (str2_len))) @@ -56,7 +58,7 @@ size_t get_val(const char kvs_name[], const char kvs_key[], char* kvs_val, stora for (i = 0; i < kvs_list_size[st_type]; i++) { if (COMPARE_STR(new_key_ptr->kvs.name, kvs_name, kvs_name_len) && COMPARE_STR(new_key_ptr->kvs.key, kvs_key, kvs_key_len)) { - STR_COPY(kvs_val, new_key_ptr->kvs.val, MAX_KVS_VAL_LENGTH); + kvs_str_copy(kvs_val, new_key_ptr->kvs.val, MAX_KVS_VAL_LENGTH); return 1; } new_key_ptr = new_key_ptr->next; @@ 
-91,7 +93,15 @@ size_t get_keys_values(const char* kvs_name, } *kvs_values = (char**)malloc(sizeof(char*) * count); + if (*kvs_values == NULL) { + printf("Memory allocation failed\n"); + exit(1); + } *kvs_keys = (char**)malloc(sizeof(char*) * count); + if (*kvs_keys == NULL) { + printf("Memory allocation failed\n"); + exit(1); + } for (i = 0; i < count; i++) { (*kvs_keys)[i] = (char*)malloc(sizeof(char) * MAX_KVS_KEY_LENGTH); @@ -101,8 +111,8 @@ size_t get_keys_values(const char* kvs_name, new_key_ptr = head[st_type]; for (i = 0; ((new_key_ptr != NULL) && (i < count));) { if (COMPARE_STR(new_key_ptr->kvs.name, kvs_name, kvs_name_len)) { - STR_COPY((*kvs_keys)[i], new_key_ptr->kvs.key, MAX_KVS_KEY_LENGTH); - STR_COPY((*kvs_values)[i], new_key_ptr->kvs.val, MAX_KVS_VAL_LENGTH); + kvs_str_copy((*kvs_keys)[i], new_key_ptr->kvs.key, MAX_KVS_KEY_LENGTH); + kvs_str_copy((*kvs_values)[i], new_key_ptr->kvs.val, MAX_KVS_VAL_LENGTH); i++; } new_key_ptr = new_key_ptr->next; @@ -167,9 +177,9 @@ void put_key(const char kvs_name[], } kvs_list_size[st_type]++; copy: - STR_COPY(tmp_key_ptr->kvs.name, kvs_name, MAX_KVS_NAME_LENGTH); - STR_COPY(tmp_key_ptr->kvs.key, kvs_key, MAX_KVS_KEY_LENGTH); - STR_COPY(tmp_key_ptr->kvs.val, kvs_val, MAX_KVS_VAL_LENGTH); + kvs_str_copy(tmp_key_ptr->kvs.name, kvs_name, MAX_KVS_NAME_LENGTH); + kvs_str_copy(tmp_key_ptr->kvs.key, kvs_key, MAX_KVS_KEY_LENGTH); + kvs_str_copy(tmp_key_ptr->kvs.val, kvs_val, MAX_KVS_VAL_LENGTH); if (strlen(kvs_name) > MAX_KVS_NAME_LENGTH) { tmp_key_ptr->kvs.name[MAX_KVS_NAME_LENGTH - 1] = NULL_CHAR; @@ -203,9 +213,9 @@ size_t cut_head(char* kvs_name, char* kvs_key, char* kvs_val, storage_type_t st_ memset(kvs_name, 0, MAX_KVS_NAME_LENGTH); memset(kvs_key, 0, MAX_KVS_KEY_LENGTH); memset(kvs_val, 0, MAX_KVS_VAL_LENGTH); - STR_COPY(kvs_name, key_ptr->kvs.name, MAX_KVS_NAME_LENGTH); - STR_COPY(kvs_key, key_ptr->kvs.key, MAX_KVS_KEY_LENGTH); - STR_COPY(kvs_val, key_ptr->kvs.val, MAX_KVS_VAL_LENGTH); + kvs_str_copy(kvs_name, 
key_ptr->kvs.name, MAX_KVS_NAME_LENGTH); + kvs_str_copy(kvs_key, key_ptr->kvs.key, MAX_KVS_KEY_LENGTH); + kvs_str_copy(kvs_val, key_ptr->kvs.val, MAX_KVS_VAL_LENGTH); free(key_ptr); kvs_list_size[st_type]--; diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.hpp similarity index 100% rename from src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.h rename to src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.hpp diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.cpp index 5d9616e42..ca4a2a59a 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.cpp @@ -55,7 +55,7 @@ int pmi_listener::collect_sock_addr(std::shared_ptr h) { char my_ip[MAX_KVS_VAL_LENGTH]; char* point_to_space; - if ((fp = popen(CHECKER_IP, READ_ONLY)) == NULL) { + if ((fp = popen(GET_IP_CMD, READ_ONLY)) == NULL) { printf("Can't get host IP\n"); exit(1); } @@ -92,6 +92,11 @@ int pmi_listener::collect_sock_addr(std::shared_ptr h) { } server_addresses = (struct sockaddr_in*)malloc((num_listeners) * sizeof(struct sockaddr_in)); + if (server_addresses == NULL) { + printf("\nmemory allocation failed \n"); + res = -1; + goto exit; + } /*get listener addresses*/ for (i = 0, j = 0; i < num_listeners; i++, j++) { @@ -107,7 +112,23 @@ int pmi_listener::collect_sock_addr(std::shared_ptr h) { i--; continue; } - server_addresses[i].sin_port = strtol(point_to_port, NULL, 10); + + if ((server_addresses[i].sin_port = strtol(point_to_port, NULL, 10)) == 0) { + /* if a conversion error occurred, display a message and exit */ + if (errno == EINVAL) { + printf("\nconversion error occurred from: %hu\n", server_addresses[i].sin_port); + res = -1; + goto exit; + } + + /* if the value provided was out of range, display a warning message */ + if (errno == 
ERANGE) { + printf("\nthe value provided was out of range, value: %hu\n", + server_addresses[i].sin_port); + res = -1; + goto exit; + } + } server_addresses[i].sin_family = AF_INET; if (inet_pton(AF_INET, sock_addr_str[j], &(server_addresses[i].sin_addr)) <= 0) { @@ -163,8 +184,9 @@ int pmi_listener::run_listener(std::shared_ptr h) { char* point_to_space; struct timeval timeout; timeout.tv_sec = LISTENER_TIMEOUT; + timeout.tv_usec = 0; - if ((fp = popen(CHECKER_IP, READ_ONLY)) == NULL) { + if ((fp = popen(GET_IP_CMD, READ_ONLY)) == NULL) { printf("Can't get host IP\n"); exit(1); } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.hpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.hpp index 87e44670f..e845ea50d 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.hpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.hpp @@ -16,7 +16,7 @@ #ifndef LISTENER_H_INCLUDED #define LISTENER_H_INCLUDED -#include "helper.h" +#include "helper.hpp" class pmi_listener { public: diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.c b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.cpp similarity index 89% rename from src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.c rename to src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.cpp index f2e567581..712c69562 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.c +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.cpp @@ -16,7 +16,7 @@ #include #include -#include "rank_list.h" +#include "rank_list.hpp" void rank_list_sort(rank_list_t* list) { rank_list_t* left = list; @@ -26,7 +26,7 @@ void rank_list_sort(rank_list_t* list) { right = left->next; while (right != NULL) { if (left->rank > right->rank) { - size_t tmp_i = left->rank; + int tmp_i = left->rank; left->rank = right->rank; right->rank = tmp_i; } @@ -48,7 +48,7 @@ void rank_list_clean(rank_list_t** list) { *list = NULL; } -size_t 
rank_list_contains(rank_list_t* list, size_t rank) { +size_t rank_list_contains(rank_list_t* list, int rank) { rank_list_t* cur_list = list; while (cur_list != NULL) { @@ -81,9 +81,13 @@ void rank_list_keep_first_n(rank_list_t** origin_list, size_t n) { (*origin_list) = NULL; } -void rank_list_add(rank_list_t** origin_list, size_t rank) { +void rank_list_add(rank_list_t** origin_list, int rank) { if ((*origin_list) == NULL) { (*origin_list) = (rank_list_t*)malloc(sizeof(rank_list_t)); + if ((*origin_list) == NULL) { + printf("Memory allocation failed\n"); + return; + } (*origin_list)->next = NULL; (*origin_list)->rank = rank; } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.hpp similarity index 87% rename from src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.h rename to src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.hpp index 97d8a1a81..064e244d6 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.hpp @@ -20,11 +20,11 @@ extern "C" { #endif typedef struct rank_list { - size_t rank; + int rank; struct rank_list* next; } rank_list_t; -size_t rank_list_contains(rank_list_t* list, size_t rank); +size_t rank_list_contains(rank_list_t* list, int rank); void rank_list_clean(rank_list_t** list); @@ -32,7 +32,7 @@ void rank_list_sort(rank_list_t* list); void rank_list_keep_first_n(rank_list_t** origin_list, size_t n); -void rank_list_add(rank_list_t** origin_list, size_t rank); +void rank_list_add(rank_list_t** origin_list, int rank); #ifdef __cplusplus } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.c b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.cpp similarity index 96% rename from src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.c rename to src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.cpp 
index 32a460be3..b72c8727e 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.c +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.cpp @@ -19,8 +19,9 @@ #include #include -#include "request_wrappers_k8s.h" #include "def.h" +#include "util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp" +#include "request_wrappers_k8s.hpp" #define JOB_NAME "CCL_JOB_NAME" @@ -114,7 +115,7 @@ void json_get_val(FILE* fp, const char** keys, size_t keys_count, char* val) { res[strlen(res) - 1] = '\0'; last_char = res[strlen(res) - 1]; } - STR_COPY(val, res, MAX_KVS_VAL_LENGTH); + kvs_str_copy(val, res, MAX_KVS_VAL_LENGTH); while (fgets(cur_kvs_str, MAX_KVS_STR_LENGTH, fp)) { } } @@ -222,7 +223,10 @@ void get_my_job_name(const char* connect_api_template) { pod_name, get_kvs_val); - fp = popen(run_str, READ_ONLY); + if ((fp = popen(run_str, READ_ONLY)) == NULL) { + printf("Can't get %s", strerror(errno)); + exit(1); + } CHECK_FGETS(fgets(job_name, MAX_KVS_NAME_LENGTH, fp), job_name); pclose(fp); if (job_name[0] == NULL_CHAR) { @@ -355,24 +359,33 @@ size_t request_k8s_kvs_finalize(size_t is_master) { size_t get_by_template(char*** kvs_entry, const char* request, - const char* template, + const char* template_str, int count, int max_count) { FILE* fp; char get_val[REQUEST_POSTFIX_SIZE]; char run_str[RUN_REQUEST_SIZE]; - size_t i; + int i; if (*kvs_entry != NULL) free(*kvs_entry); *kvs_entry = (char**)malloc(sizeof(char*) * count); - for (i = 0; i < count; i++) + if (*kvs_entry == NULL) { + printf("Memory allocation failed\n"); + exit(1); + } + for (i = 0; i < count; i++) { (*kvs_entry)[i] = (char*)malloc(sizeof(char) * max_count); + if ((*kvs_entry)[i] == NULL) { + printf("Memory allocation failed\n"); + exit(1); + } + } i = 0; - SET_STR(get_val, REQUEST_POSTFIX_SIZE, CONCAT_TWO_COMMAND_TEMPLATE, request, template); + SET_STR(get_val, REQUEST_POSTFIX_SIZE, CONCAT_TWO_COMMAND_TEMPLATE, request, template_str); SET_STR(run_str, 
RUN_REQUEST_SIZE, run_get_template, get_val); if ((fp = popen(run_str, READ_ONLY)) == NULL) { printf("Can't get by template\n"); diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.hpp similarity index 100% rename from src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.h rename to src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.hpp diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.cpp index 23addcc7f..5e947e01d 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.cpp @@ -16,7 +16,7 @@ #include "atl/util/pm/pmi_resizable_rt/pmi_resizable.h" #include "util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h" -static size_t root_rank = 0; +static int root_rank = 0; static size_t is_new_root = 0; static size_t ask_only_framework = 0; static size_t finalized = 0; @@ -31,9 +31,9 @@ void Call_Hard_finilize(int sig) { pmi_object->Hard_finilize(sig); } -kvs_resize_action_t pmi_resizable::default_checker(size_t comm_size) { +kvs_resize_action_t pmi_resizable::default_checker(int comm_size) { char* comm_size_to_start_env; - size_t comm_size_to_start; + int comm_size_to_start; comm_size_to_start_env = getenv(CCL_WORLD_SIZE_ENV); @@ -47,7 +47,7 @@ kvs_resize_action_t pmi_resizable::default_checker(size_t comm_size) { return KVS_RA_WAIT; } -kvs_resize_action_t pmi_resizable::call_resize_fn(size_t comm_size) { +kvs_resize_action_t pmi_resizable::call_resize_fn(int comm_size) { if (resize_function != nullptr) return resize_function(comm_size); @@ -56,8 +56,8 @@ kvs_resize_action_t pmi_resizable::call_resize_fn(size_t comm_size) { int pmi_resizable::PMIR_Update(void) { char up_idx_str[MAX_KVS_VAL_LENGTH]; - size_t prev_new_ranks_count = 0; - size_t 
prev_killed_ranks_count = 0; + int prev_new_ranks_count = 0; + int prev_killed_ranks_count = 0; int prev_idx = -1; kvs_resize_action_t answer; rank_list_t* dead_up_idx = NULL; @@ -100,7 +100,7 @@ int pmi_resizable::PMIR_Update(void) { // while (int_list_is_contained(killed_ranks, root_rank) == 1) { - size_t old_root = root_rank; + int old_root = root_rank; h->get_new_root(&root_rank); if (my_rank == root_rank && old_root != root_rank) @@ -151,7 +151,8 @@ int pmi_resizable::PMIR_Update(void) { if (!is_first_collect || ask_only_framework == 1) answer = call_resize_fn(count_pods - killed_ranks_count + new_ranks_count); else { - if (h->get_replica_size() != count_pods - killed_ranks_count + new_ranks_count) + if ((int)(h->get_replica_size()) != + count_pods - killed_ranks_count + new_ranks_count) answer = KVS_RA_WAIT; else answer = KVS_RA_RUN; @@ -202,7 +203,7 @@ int pmi_resizable::PMIR_Update(void) { void pmi_resizable::Hard_finilize(int sig) { char rank_str[INT_STR_SIZE]; - SET_STR(rank_str, INT_STR_SIZE, SIZE_T_TEMPLATE, my_rank); + SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); h->set_value(KVS_DEAD_POD, my_hostname, rank_str); @@ -214,7 +215,7 @@ void pmi_resizable::Hard_finilize(int sig) { old_act.sa_handler(sig); } -int pmi_resizable::PMIR_Main_Addr_Reserv(char* main_addr) { +int pmi_resizable::PMIR_Main_Addr_Reserve(char* main_addr) { h->main_server_address_reserve(main_addr); return 0; } @@ -272,7 +273,7 @@ int pmi_resizable::PMIR_Finalize(void) { applied = 0; - SET_STR(rank_str, INT_STR_SIZE, SIZE_T_TEMPLATE, my_rank); + SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); h->remove_name_key(KVS_POD_NUM, rank_str); @@ -313,18 +314,18 @@ int pmi_resizable::PMIR_Barrier(void) { return 0; } -int pmi_resizable::PMIR_Get_size(size_t* size) { +int pmi_resizable::PMIR_Get_size(int* size) { *size = count_pods; return 0; } -int pmi_resizable::PMIR_Get_rank(size_t* rank) { +int pmi_resizable::PMIR_Get_rank(int* rank) { *rank = my_rank; return 0; } int 
pmi_resizable::PMIR_KVS_Get_my_name(char* kvs_name, size_t length) { - STR_COPY(kvs_name, KVS_NAME, length); + kvs_str_copy(kvs_name, KVS_NAME, length); return 0; } @@ -372,15 +373,15 @@ int pmi_resizable::PMIR_Wait_notification(void) { return listener.run_listener(h); } -size_t pmi_resizable::get_rank() { +int pmi_resizable::get_rank() { return rank; } -size_t pmi_resizable::get_size() { +int pmi_resizable::get_size() { return size; } -size_t pmi_resizable::get_thread() { +size_t pmi_resizable::get_local_thread_idx() { return 0; } size_t pmi_resizable::get_local_kvs_id() { diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.h index 85a6eac0c..b9ef14a62 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.h @@ -44,17 +44,17 @@ typedef enum { KVS_RA_RUN = 1, KVS_RA_FINALIZE = 2, } kvs_resize_action_t; -typedef kvs_resize_action_t (*pmir_resize_fn_t)(size_t comm_size); +typedef kvs_resize_action_t (*pmir_resize_fn_t)(int comm_size); -int PMIR_API PMIR_Main_Addr_Reserv(char* main_addr); +int PMIR_API PMIR_Main_Addr_Reserve(char* main_addr); int PMIR_API PMIR_Init(const char* main_addr); int PMIR_API PMIR_Finalize(void); -int PMIR_API PMIR_Get_size(size_t* size); +int PMIR_API PMIR_Get_size(int* size); -int PMIR_API PMIR_Get_rank(size_t* rank); +int PMIR_API PMIR_Get_rank(int* rank); int PMIR_API PMIR_KVS_Get_my_name(char* kvs_name, size_t length); diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.c b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.cpp similarity index 86% rename from src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.c rename to src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.cpp index da5bdb652..30fe48722 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.c +++ 
b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.cpp @@ -16,7 +16,7 @@ #include #include -#include "shift_list.h" +#include "shift_list.hpp" void shift_list_clean(shift_list_t** list) { shift_list_t* cur_list = (*list); @@ -29,10 +29,14 @@ void shift_list_clean(shift_list_t** list) { (*list) = NULL; } -void shift_list_add(shift_list_t** list, size_t old_rank, size_t new_rank, change_type_t type) { +void shift_list_add(shift_list_t** list, int old_rank, int new_rank, change_type_t type) { shift_list_t* cur_list; if ((*list) == NULL) { (*list) = (shift_list_t*)malloc(sizeof(shift_list_t)); + if ((*list) == NULL) { + printf("Memory allocation failed\n"); + return; + } cur_list = (*list); } else { diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.hpp similarity index 88% rename from src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.h rename to src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.hpp index 7983f4dcc..59f3d6a5e 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.hpp @@ -27,8 +27,8 @@ typedef enum change_type { } change_type_t; typedef struct shift_rank { - size_t old_rank; - size_t new_rank; + int old_rank; + int new_rank; change_type_t type; } shift_rank_t; @@ -39,7 +39,7 @@ typedef struct shift_list { void shift_list_clean(shift_list_t** list); -void shift_list_add(shift_list_t** list, size_t old_rank, size_t new_rank, change_type_t type); +void shift_list_add(shift_list_t** list, int old_rank, int new_rank, change_type_t type); #ifdef __cplusplus } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_rt.c b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_rt.c index 58b770e2f..4cfc14dc3 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_rt.c +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_rt.c @@ -23,7 +23,7 @@ #include "pm_rt.h" -#define 
RESIZABLE_PMI_RT_KEY_FORMAT "%s-%zu" +#define RESIZABLE_PMI_RT_KEY_FORMAT "%s-%d" typedef struct resizable_pm_rt_context { pm_rt_desc_t pmrt_desc; @@ -69,7 +69,7 @@ static void resizable_pmirt_barrier(pm_rt_desc_t *pmrt_desc) { static atl_status_t resizable_pmirt_kvs_put(pm_rt_desc_t *pmrt_desc, char *kvs_key, - size_t proc_idx, + int proc_idx, const void *kvs_val, size_t kvs_val_len) { int ret; @@ -109,7 +109,7 @@ static atl_status_t resizable_pmirt_kvs_put(pm_rt_desc_t *pmrt_desc, static atl_status_t resizable_pmirt_kvs_get(pm_rt_desc_t *pmrt_desc, char *kvs_key, - size_t proc_idx, + int proc_idx, void *kvs_val, size_t kvs_val_len) { int ret; @@ -140,7 +140,7 @@ static atl_status_t resizable_pmirt_kvs_get(pm_rt_desc_t *pmrt_desc, return ATL_STATUS_SUCCESS; } -static atl_status_t resizable_pmirt_update(size_t *proc_idx, size_t *proc_count) { +static atl_status_t resizable_pmirt_update(int *proc_idx, int *proc_count) { int ret; ret = PMIR_Update(); if (ret != PMIR_SUCCESS) @@ -184,8 +184,8 @@ pm_rt_kvs_ops_t resizable_kvs_ops = { .get = resizable_pmirt_kvs_get, }; -atl_status_t resizable_pmirt_init(size_t *proc_idx, - size_t *proc_count, +atl_status_t resizable_pmirt_init(int *proc_idx, + int *proc_count, pm_rt_desc_t **pmrt_desc, const char *main_addr) { int ret; @@ -260,8 +260,8 @@ atl_status_t resizable_pmirt_init(size_t *proc_idx, return ATL_STATUS_FAILURE; } -atl_status_t resizable_pmirt_main_addr_reserv(char *main_addr) { - int ret = PMIR_Main_Addr_Reserv(main_addr); +atl_status_t resizable_pmirt_main_addr_reserve(char *main_addr) { + int ret = PMIR_Main_Addr_Reserve(main_addr); if (ret) return ATL_STATUS_FAILURE; diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.cpp index 1a03c6a43..b797f5c3c 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.cpp @@ -16,23 +16,24 @@ #include #include 
"util/pm/pmi_resizable_rt/pmi_resizable/def.h" -#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.h" +#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.hpp" #include "pmi_resizable_simple.h" #include "util/pm/codec/pm_rt_codec.h" -#define RESIZABLE_PMI_RT_KEY_FORMAT "%s-%zu" -#define DEVICES_PER_THREAD "DEVICES_PER_THREAD" +#define RESIZABLE_PMI_RT_KEY_FORMAT "%s-%d" +#define RANKS_PER_THREAD "RANKS_PER_THREAD" #define PROCESS_THREAD_NAME "PROCESS_THREAD_NAME" -#define REQUESTED_RANK_TO_NAME "REQUESTED_RANK_TO_NAME" -#define GLOBAL_NAME_TO_RANK "GLOBAL_NAME_TO_RANK" -#define GLOBAL_RANK_TO_NAME "GLOBAL_RANK_TO_NAME" -#define LOCAL_KVS_ID "LOCAL_KVS_ID" -pmi_resizable_simple::pmi_resizable_simple(size_t size, - const std::vector& ranks, +#define REQUESTED_RANK_TO_NAME "REQUESTED_RANK_TO_NAME" +#define GLOBAL_NAME_TO_RANK "GLOBAL_NAME_TO_RANK" +#define GLOBAL_RANK_TO_NAME "GLOBAL_RANK_TO_NAME" +#define LOCAL_KVS_ID "LOCAL_KVS_ID" + +pmi_resizable_simple::pmi_resizable_simple(int size, + const std::vector& ranks, std::shared_ptr k, const char* main_addr) - : dev_count(size), + : total_rank_count(size), ranks(ranks), k(k) { max_keylen = MAX_KVS_KEY_LENGTH; @@ -47,14 +48,14 @@ int pmi_resizable_simple::is_pm_resize_enabled() { atl_status_t pmi_resizable_simple::pmrt_init(const char* main_addr) { (void)main_addr; char* connection_timeout_str = getenv("CCL_KVS_GET_TIMEOUT"); - if (connection_timeout_str) - { + if (connection_timeout_str) { connection_timeout = atoi(connection_timeout_str); } local_id = 0; val_storage = (char*)calloc(1, max_vallen); if (!val_storage) return ATL_STATUS_FAILURE; + /*TODO: add sort, ranks should increase continiusly*/ if (ranks[0] == 0) { size_t tmp_local_id = get_local_kvs_id(); tmp_local_id++; @@ -68,18 +69,18 @@ atl_status_t pmi_resizable_simple::pmrt_init(const char* main_addr) { } void pmi_resizable_simple::make_requested_info() { - register_my_first_rank_and_dev_count(); - 
get_requested_thread_num_and_threads_count(); + register_first_rank_idx_and_rank_count(); + assign_thread_idx_and_fill_ranks_per_thread_map(); local_id = get_local_kvs_id(); register_my_proc_name(); - get_my_proc_num_and_proc_count(); - get_local_thread_num(); + get_my_proc_idx_and_proc_count(); + calculate_local_thread_idx(); remove_initial_data(); - pmrt_barrier(); + pmrt_barrier_full(); } -atl_status_t pmi_resizable_simple::pmrt_main_addr_reserv(char* main_addr) { +atl_status_t pmi_resizable_simple::pmrt_main_addr_reserve(char* main_addr) { return ATL_STATUS_UNSUPPORTED; } @@ -99,9 +100,9 @@ void pmi_resizable_simple::pmrt_finalize() { is_finalized = true; free(val_storage); - if (getenv("CCL_PMI_FORCE_FINALIZE")) - { - printf("skip pmi_resizable_simple::pmrt_finalize\n"); fflush(stdout); + if (getenv("CCL_PMI_FORCE_FINALIZE")) { + printf("skip pmi_resizable_simple::pmrt_finalize\n"); + fflush(stdout); return; } @@ -120,7 +121,7 @@ void pmi_resizable_simple::pmrt_barrier() { SET_STR(barrier_num_str, INT_STR_SIZE, SIZE_T_TEMPLATE, barrier_num); - kvs_set_value(KVS_BARRIER, std::to_string(requested_rank_num).c_str(), barrier_num_str); + kvs_set_value(KVS_BARRIER, std::to_string(assigned_proc_idx).c_str(), barrier_num_str); min_barrier_num = get_barrier_idx(); while (min_barrier_num != barrier_num) { @@ -131,9 +132,44 @@ void pmi_resizable_simple::pmrt_barrier() { if (barrier_num > BARRIER_NUM_MAX) barrier_num = 0; } +void pmi_resizable_simple::pmrt_barrier_full() { + size_t min_barrier_num; + char barrier_num_str[INT_STR_SIZE]; + + SET_STR(barrier_num_str, INT_STR_SIZE, SIZE_T_TEMPLATE, barrier_num_full); + + kvs_set_value(KVS_BARRIER_FULL, std::to_string(assigned_thread_idx).c_str(), barrier_num_str); + + min_barrier_num = get_barrier_full_idx(); + while (min_barrier_num != barrier_num) { + min_barrier_num = get_barrier_idx(); + } + + barrier_num_full++; + if (barrier_num_full > BARRIER_NUM_MAX) + barrier_num_full = 0; +} + +size_t 
pmi_resizable_simple::get_barrier_full_idx() { + size_t thread_count = ranks_per_thread_map.size(); + + kvs_get_value(KVS_BARRIER_FULL, std::to_string(0).c_str(), val_storage); + size_t min_barrier_idx = atoi(val_storage); + size_t barrier_idx; + for (size_t i = 1; i < thread_count; i++) { + kvs_get_value(KVS_BARRIER_FULL, std::to_string(i).c_str(), val_storage); + + barrier_idx = atoi(val_storage); + + if (min_barrier_idx > barrier_idx) + min_barrier_idx = barrier_idx; + } + + return min_barrier_idx; +} atl_status_t pmi_resizable_simple::pmrt_kvs_put(char* kvs_key, - size_t proc_idx, + int proc_idx, const void* kvs_val, size_t kvs_val_len) { int ret; @@ -155,7 +191,7 @@ atl_status_t pmi_resizable_simple::pmrt_kvs_put(char* kvs_key, } atl_status_t pmi_resizable_simple::pmrt_kvs_get(char* kvs_key, - size_t proc_idx, + int proc_idx, void* kvs_val, size_t kvs_val_len) { int ret; @@ -174,16 +210,16 @@ atl_status_t pmi_resizable_simple::pmrt_kvs_get(char* kvs_key, return ATL_STATUS_SUCCESS; } -size_t pmi_resizable_simple::get_size() { - return threads_per_rank.size(); +int pmi_resizable_simple::get_size() { + return threads_per_proc.size(); } -size_t pmi_resizable_simple::get_rank() { - return requested_rank_num; +int pmi_resizable_simple::get_rank() { + return assigned_proc_idx; } -size_t pmi_resizable_simple::get_thread() { - return local_thread_num; +size_t pmi_resizable_simple::get_local_thread_idx() { + return local_thread_idx; } int pmi_resizable_simple::kvs_set_value(const char* kvs_name, const char* key, const char* value) { @@ -199,12 +235,15 @@ int pmi_resizable_simple::kvs_get_value(const char* kvs_name, const char* key, c size_t connection_time = 0; start_time = time(NULL); while (k->kvs_get_value_by_name_key(result_kvs_name.c_str(), key, value) == 0 && - connection_time < connection_timeout) { + connection_time < connection_timeout) { connection_time = time(NULL) - start_time; } if (connection_time >= connection_timeout) { printf("KVS get error: timeout 
limit (%zu > %zu), prefix: %s, key %s\n", - connection_time, connection_timeout, result_kvs_name.c_str(), key); + connection_time, + connection_timeout, + result_kvs_name.c_str(), + key); exit(1); } @@ -214,10 +253,9 @@ int pmi_resizable_simple::kvs_get_value(const char* kvs_name, const char* key, c int pmi_resizable_simple::kvs_iget_value(const char* kvs_name, const char* key, char* value) { std::string result_kvs_name = std::string(kvs_name) + std::to_string(local_id); return k->kvs_get_value_by_name_key(result_kvs_name.c_str(), key, value); - ; } size_t pmi_resizable_simple::get_barrier_idx() { - size_t proc_count = threads_per_rank.size(); + size_t proc_count = threads_per_proc.size(); kvs_get_value(KVS_BARRIER, std::to_string(0).c_str(), val_storage); @@ -235,23 +273,23 @@ size_t pmi_resizable_simple::get_barrier_idx() { return min_barrier_idx; } -void pmi_resizable_simple::register_my_first_rank_and_dev_count() { +void pmi_resizable_simple::register_first_rank_idx_and_rank_count() { kvs_set_value( - DEVICES_PER_THREAD, std::to_string(ranks[0]).c_str(), std::to_string(ranks.size()).c_str()); + RANKS_PER_THREAD, std::to_string(ranks[0]).c_str(), std::to_string(ranks.size()).c_str()); } -void pmi_resizable_simple::get_requested_thread_num_and_threads_count() { - size_t total_dev_count = 0; - size_t devises; - while (total_dev_count < dev_count) { - if (total_dev_count == ranks[0]) { - requested_thread_num = devises_per_thread.size(); +void pmi_resizable_simple::assign_thread_idx_and_fill_ranks_per_thread_map() { + int rank_count = 0; + int ranks_per_thread; + while (rank_count < total_rank_count) { + if (rank_count == ranks[0]) { + assigned_thread_idx = ranks_per_thread_map.size(); } - kvs_get_value(DEVICES_PER_THREAD, std::to_string(total_dev_count).c_str(), val_storage); + kvs_get_value(RANKS_PER_THREAD, std::to_string(rank_count).c_str(), val_storage); - devises = atoi(val_storage); - devises_per_thread.push_back(devises); - total_dev_count += devises; + 
ranks_per_thread = atoi(val_storage); + ranks_per_thread_map.push_back(ranks_per_thread); + rank_count += ranks_per_thread; } } @@ -266,46 +304,45 @@ void pmi_resizable_simple::register_my_proc_name() { } my_proccess_name = std::string(hostname) + std::to_string(my_pid); - kvs_set_value(PROCESS_THREAD_NAME, - std::to_string(requested_thread_num).c_str(), - my_proccess_name.c_str()); + kvs_set_value( + PROCESS_THREAD_NAME, std::to_string(assigned_thread_idx).c_str(), my_proccess_name.c_str()); } -void pmi_resizable_simple::get_my_proc_num_and_proc_count() { - std::map proc_name_to_rank; - std::map::iterator it; - size_t rank; - for (size_t i = 0; i < devises_per_thread.size(); i++) { +void pmi_resizable_simple::get_my_proc_idx_and_proc_count() { + std::map proc_name_to_rank; + std::map::iterator it; + int rank; + for (size_t i = 0; i < ranks_per_thread_map.size(); i++) { kvs_get_value(PROCESS_THREAD_NAME, std::to_string(i).c_str(), val_storage); it = proc_name_to_rank.find(val_storage); if (it == proc_name_to_rank.end()) { - rank = threads_per_rank.size(); + rank = threads_per_proc.size(); if (!my_proccess_name.compare(val_storage)) { - requested_rank_num = rank; - if (requested_thread_num == i) { + assigned_proc_idx = rank; + if (assigned_thread_idx == i) { kvs_set_value(REQUESTED_RANK_TO_NAME, - std::to_string(requested_rank_num).c_str(), + std::to_string(assigned_proc_idx).c_str(), my_proccess_name.c_str()); } } proc_name_to_rank[val_storage] = rank; - threads_per_rank[rank].push_back(i); + threads_per_proc[rank].push_back(i); } else { - threads_per_rank[it->second].push_back(i); + threads_per_proc[it->second].push_back(i); } } } -void pmi_resizable_simple::get_local_thread_num() { - local_thread_num = 0; - for (auto it = threads_per_rank[requested_rank_num].begin(); - it != threads_per_rank[requested_rank_num].end(); +void pmi_resizable_simple::calculate_local_thread_idx() { + local_thread_idx = 0; + for (auto it = threads_per_proc[assigned_proc_idx].begin(); + 
it != threads_per_proc[assigned_proc_idx].end(); it++) { - if (requested_thread_num == *it) + if (assigned_thread_idx == *it) break; - local_thread_num++; + local_thread_idx++; } } @@ -314,7 +351,7 @@ void pmi_resizable_simple::make_map_requested2global() { char process_name[MAX_KVS_VAL_LENGTH]; size_t size = get_size(); requested2global.resize(size); - pmrt_barrier(); + pmrt_barrier_full(); for (size_t i = 0; i < size; i++) { kvs_get_value(REQUESTED_RANK_TO_NAME, std::to_string(i).c_str(), process_name); if (kvs_iget_value(GLOBAL_NAME_TO_RANK, process_name, global_rank_str) == 0) { @@ -336,7 +373,7 @@ void pmi_resizable_simple::make_map_requested2global() { } requested2global[i] = atoi(global_rank_str); } - pmrt_barrier(); + pmrt_barrier_full(); } size_t pmi_resizable_simple::get_local_kvs_id() { @@ -357,7 +394,7 @@ pmi_resizable_simple::~pmi_resizable_simple() { pmrt_finalize(); } void pmi_resizable_simple::remove_initial_data() { - std::string result_kvs_name = std::string(DEVICES_PER_THREAD) + std::to_string(0); + std::string result_kvs_name = std::string(RANKS_PER_THREAD) + std::to_string(0); remove_val(result_kvs_name.c_str(), std::to_string(ranks[0]).c_str(), ST_CLIENT); k->kvs_remove_name_key(result_kvs_name.c_str(), std::to_string(ranks[0]).c_str()); } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.h index 9f3544198..86d0a2f04 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.h @@ -41,8 +41,8 @@ class pmi_resizable_simple final : public ipmi { public: pmi_resizable_simple() = delete; - pmi_resizable_simple(size_t dev_count, - const std::vector& ranks, + pmi_resizable_simple(int total_rank_count, + const std::vector& ranks, std::shared_ptr k, const char* main_addr = nullptr); @@ -50,7 +50,7 @@ class pmi_resizable_simple final : public ipmi { int is_pm_resize_enabled() override; - atl_status_t 
pmrt_main_addr_reserv(char* main_addr) override; + atl_status_t pmrt_main_addr_reserve(char* main_addr) override; atl_status_t pmrt_set_resize_function(atl_resize_fn_t resize_fn) override; @@ -61,68 +61,77 @@ class pmi_resizable_simple final : public ipmi { void pmrt_barrier() override; atl_status_t pmrt_kvs_put(char* kvs_key, - size_t proc_idx, + int proc_idx, const void* kvs_val, size_t kvs_val_len) override; atl_status_t pmrt_kvs_get(char* kvs_key, - size_t proc_idx, + int proc_idx, void* kvs_val, size_t kvs_val_len) override; - size_t get_size() override; + int get_size() override; - size_t get_rank() override; + int get_rank() override; - size_t get_thread() override; + size_t get_local_thread_idx() override; size_t get_local_kvs_id() override; void set_local_kvs_id(size_t local_kvs_id) override; - size_t get_threads_count() override { - return threads_per_rank[requested_rank_num].size(); + size_t get_threads_per_process() override { + return threads_per_proc[assigned_proc_idx].size(); } - size_t get_devices_per_rank_count() override { + size_t get_ranks_per_process() override { size_t res = 0; - std::list& threads = threads_per_rank[requested_rank_num]; - for (auto it = threads.begin(); it != threads.end(); it++) { - res += devises_per_thread[*it]; + std::list& thread_idxs = threads_per_proc[assigned_proc_idx]; + for (auto it = thread_idxs.begin(); it != thread_idxs.end(); it++) { + res += ranks_per_thread_map[*it]; } return res; } + void pmrt_finalize() override; private: bool is_finalized{ false }; atl_status_t pmrt_init(const char* main_addr = nullptr); + int kvs_set_value(const char* kvs_name, const char* key, const char* value); int kvs_get_value(const char* kvs_name, const char* key, char* value); int kvs_iget_value(const char* kvs_name, const char* key, char* value); + size_t get_barrier_idx(); - void register_my_first_rank_and_dev_count(); - void get_requested_thread_num_and_threads_count(); + size_t get_barrier_full_idx(); + + void 
calculate_local_thread_idx(); + void register_first_rank_idx_and_rank_count(); + void assign_thread_idx_and_fill_ranks_per_thread_map(); void register_my_proc_name(); - void get_my_proc_num_and_proc_count(); - void get_local_thread_num(); + void get_my_proc_idx_and_proc_count(); void make_requested_info(); void remove_initial_data(); void make_map_requested2global(); - size_t dev_count; - size_t requested_rank_num; - size_t requested_thread_num; - size_t local_thread_num; + void pmrt_barrier_full(); + + int total_rank_count; + int assigned_proc_idx; + + size_t assigned_thread_idx; + size_t local_thread_idx; std::string my_proccess_name; - std::vector ranks; - std::vector devises_per_thread; - std::map> threads_per_rank; + std::vector ranks; + std::vector ranks_per_thread_map; + std::map> threads_per_proc; std::shared_ptr k; size_t max_keylen; size_t max_vallen; char* val_storage = nullptr; size_t barrier_num = 0; - std::vector requested2global; + size_t barrier_num_full = 0; + std::vector requested2global; size_t local_id; - size_t connection_timeout = 120; + size_t connection_timeout = 120; /* in seconds */ }; diff --git a/src/atl/util/pm/pmi_rt/pmi/CMakeLists.txt b/src/atl/util/pm/pmi_rt/pmi/CMakeLists.txt index 2880ae331..5917c927f 100755 --- a/src/atl/util/pm/pmi_rt/pmi/CMakeLists.txt +++ b/src/atl/util/pm/pmi_rt/pmi/CMakeLists.txt @@ -13,35 +13,35 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -#builds pmi - -set(PMI_SRC - simple_pmiutil.c - simple_pmi.c) - -set(COMMON_PMI_INC_DIRS - ${PROJECT_SOURCE_DIR}/src/atl/util/pm/pmi_rt/pmi) - -#special library that holds objects only -add_library(pmi-objects OBJECT ${PMI_SRC}) -set_target_properties(pmi-objects PROPERTIES POSITION_INDEPENDENT_CODE 1) -target_include_directories(pmi-objects PUBLIC ${COMMON_PMI_INC_DIRS}) -target_compile_definitions(pmi-objects PRIVATE HAVE_UNISTD_H HAVE_STDLIB_H HAVE_STRING_H HAVE_STRINGS_H) - -#shared lib -add_library(pmi SHARED $) -target_include_directories(pmi PUBLIC INTERFACE ${COMMON_PMI_INC_DIRS}) -if (NOT LIB_PMI_SO_VERSION AND NOT LIB_PMI_MAJOR_VERSION) - set_target_properties(pmi PROPERTIES VERSION 1 SOVERSION 1.0) -else() - set_target_properties(pmi PROPERTIES VERSION ${LIB_PMI_SO_VERSION} SOVERSION ${LIB_PMI_MAJOR_VERSION}) -endif() -set_target_properties(pmi PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) - -install(TARGETS pmi LIBRARY DESTINATION ${CCL_INSTALL_LIB}) - -#static lib -add_library(pmi-static STATIC $) -set_target_properties(pmi-static PROPERTIES OUTPUT_NAME pmi) -set_target_properties(pmi-static PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) -install(TARGETS pmi-static ARCHIVE DESTINATION ${CCL_INSTALL_LIB}) +#builds pmi + +set(PMI_SRC + simple_pmiutil.c + simple_pmi.c) + +set(COMMON_PMI_INC_DIRS + ${PROJECT_SOURCE_DIR}/src/atl/util/pm/pmi_rt/pmi) + +#special library that holds objects only +add_library(pmi-objects OBJECT ${PMI_SRC}) +set_target_properties(pmi-objects PROPERTIES POSITION_INDEPENDENT_CODE 1) +target_include_directories(pmi-objects PUBLIC ${COMMON_PMI_INC_DIRS}) +target_compile_definitions(pmi-objects PRIVATE HAVE_UNISTD_H HAVE_STDLIB_H HAVE_STRING_H HAVE_STRINGS_H) + +#shared lib +add_library(pmi SHARED $) +target_include_directories(pmi PUBLIC INTERFACE ${COMMON_PMI_INC_DIRS}) +if (NOT LIB_PMI_SO_VERSION AND NOT LIB_PMI_MAJOR_VERSION) + set_target_properties(pmi PROPERTIES VERSION 1 SOVERSION 1.0) +else() + 
set_target_properties(pmi PROPERTIES VERSION ${LIB_PMI_SO_VERSION} SOVERSION ${LIB_PMI_MAJOR_VERSION}) +endif() +set_target_properties(pmi PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) + +install(TARGETS pmi LIBRARY DESTINATION ${CCL_INSTALL_LIB}) + +#static lib +add_library(pmi-static STATIC $) +set_target_properties(pmi-static PROPERTIES OUTPUT_NAME pmi) +set_target_properties(pmi-static PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${CCL_BUILD_DIR}) +install(TARGETS pmi-static ARCHIVE DESTINATION ${CCL_INSTALL_LIB}) diff --git a/src/atl/util/pm/pmi_rt/pmi/simple_pmiutil.c b/src/atl/util/pm/pmi_rt/pmi/simple_pmiutil.c index 102d301d8..59702c88f 100644 --- a/src/atl/util/pm/pmi_rt/pmi/simple_pmiutil.c +++ b/src/atl/util/pm/pmi_rt/pmi/simple_pmiutil.c @@ -94,9 +94,17 @@ void PMIU_printf(int print_flag, const char *fmt, ...) { if (p) { MPL_snprintf(filename, sizeof(filename), "testclient-%s.out", p); logfile = fopen(filename, "w"); + if (logfile == NULL) { + printf("Error opening file %s \n", strerror(errno)); + return; + } } else { logfile = fopen("testserver.out", "w"); + if (logfile == NULL) { + printf("Error opening file %s \n", strerror(errno)); + return; + } } } else diff --git a/src/atl/util/pm/pmi_rt/pmi_rt.c b/src/atl/util/pm/pmi_rt/pmi_rt.c index 3d8dd52e9..836cb95cc 100644 --- a/src/atl/util/pm/pmi_rt/pmi_rt.c +++ b/src/atl/util/pm/pmi_rt/pmi_rt.c @@ -21,7 +21,7 @@ #include "util/pm/pm_rt.h" -#define PMI_RT_KEY_FORMAT "%s-%zu" +#define PMI_RT_KEY_FORMAT "%s-%d" typedef struct pmi_pm_rt_context { pm_rt_desc_t pmrt_desc; @@ -58,7 +58,7 @@ static void pmirt_finalize(pm_rt_desc_t *pmrt_desc) { static atl_status_t pmirt_kvs_put(pm_rt_desc_t *pmrt_desc, char *kvs_key, - size_t proc_idx, + int proc_idx, const void *kvs_val, size_t kvs_val_len) { int ret; @@ -96,7 +96,7 @@ static atl_status_t pmirt_kvs_put(pm_rt_desc_t *pmrt_desc, static atl_status_t pmirt_kvs_get(pm_rt_desc_t *pmrt_desc, char *kvs_key, - size_t proc_idx, + int proc_idx, void *kvs_val, size_t 
kvs_val_len) { int ret; @@ -137,7 +137,7 @@ static void pmirt_barrier(pm_rt_desc_t *pmrt_desc) { (void)PMI_Barrier(); } -atl_status_t pmirt_update(size_t *proc_idx, size_t *proc_count) { +atl_status_t pmirt_update(int *proc_idx, int *proc_count) { PMI_Get_size((int *)proc_idx); PMI_Get_rank((int *)proc_count); return ATL_STATUS_SUCCESS; @@ -159,7 +159,7 @@ pm_rt_kvs_ops_t kvs_ops = { .get = pmirt_kvs_get, }; -atl_status_t pmirt_init(size_t *proc_idx, size_t *proc_count, pm_rt_desc_t **pmrt_desc) { +atl_status_t pmirt_init(int *proc_idx, int *proc_count, pm_rt_desc_t **pmrt_desc) { int ret, spawned, max_kvsnamelen; int proc_idx_tmp, proc_count_tmp; diff --git a/src/atl/util/pm/pmi_rt/pmi_simple.cpp b/src/atl/util/pm/pmi_rt/pmi_simple.cpp index 9ac01602c..6e14b8529 100644 --- a/src/atl/util/pm/pmi_rt/pmi_simple.cpp +++ b/src/atl/util/pm/pmi_rt/pmi_simple.cpp @@ -24,7 +24,7 @@ pmi_simple::pmi_simple() { pmirt_init(&rank, &size, &pmrt_desc); } -atl_status_t pmi_simple::pmrt_main_addr_reserv(char *main_addr) { +atl_status_t pmi_simple::pmrt_main_addr_reserve(char *main_addr) { printf("Function main_addr_reserv unsupported yet for simple pmi\n"); return ATL_STATUS_FAILURE; } @@ -54,28 +54,28 @@ void pmi_simple::pmrt_barrier() { } atl_status_t pmi_simple::pmrt_kvs_put(char *kvs_key, - size_t proc_idx, + int proc_idx, const void *kvs_val, size_t kvs_val_len) { return pmirt_kvs_put(pmrt_desc, kvs_key, proc_idx, kvs_val, kvs_val_len); } atl_status_t pmi_simple::pmrt_kvs_get(char *kvs_key, - size_t proc_idx, + int proc_idx, void *kvs_val, size_t kvs_val_len) { return pmirt_kvs_get(pmrt_desc, kvs_key, proc_idx, kvs_val, kvs_val_len); } -size_t pmi_simple::get_rank() { +int pmi_simple::get_rank() { return rank; } -size_t pmi_simple::get_size() { +int pmi_simple::get_size() { return size; } -size_t pmi_simple::get_thread() { +size_t pmi_simple::get_local_thread_idx() { return 0; } size_t pmi_simple::get_local_kvs_id() { diff --git a/src/atl/util/pm/pmi_rt/pmi_simple.h 
b/src/atl/util/pm/pmi_rt/pmi_simple.h index 87d50d40a..27d8b0571 100644 --- a/src/atl/util/pm/pmi_rt/pmi_simple.h +++ b/src/atl/util/pm/pmi_rt/pmi_simple.h @@ -23,7 +23,7 @@ class pmi_simple final : public ipmi { int is_pm_resize_enabled() override; - atl_status_t pmrt_main_addr_reserv(char *main_addr) override; + atl_status_t pmrt_main_addr_reserve(char *main_addr) override; atl_status_t pmrt_set_resize_function(atl_resize_fn_t resize_fn) override; @@ -36,36 +36,36 @@ class pmi_simple final : public ipmi { void pmrt_barrier() override; atl_status_t pmrt_kvs_put(char *kvs_key, - size_t proc_idx, + int proc_idx, const void *kvs_val, size_t kvs_val_len) override; atl_status_t pmrt_kvs_get(char *kvs_key, - size_t proc_idx, + int proc_idx, void *kvs_val, size_t kvs_val_len) override; - size_t get_rank() override; + int get_rank() override; - size_t get_size() override; + int get_size() override; - size_t get_thread() override; + size_t get_local_thread_idx() override; size_t get_local_kvs_id() override; void set_local_kvs_id(size_t local_kvs_id) override; - size_t get_threads_count() override { + size_t get_threads_per_process() override { return 1; } - size_t get_devices_per_rank_count() override { + size_t get_ranks_per_process() override { return 1; } private: - size_t rank; - size_t size; + int rank; + int size; pm_rt_desc_t *pmrt_desc = nullptr; bool is_finalized{ false }; }; diff --git a/src/ccl.cpp b/src/ccl.cpp deleted file mode 100644 index 8bf262359..000000000 --- a/src/ccl.cpp +++ /dev/null @@ -1,475 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "common/global/global.hpp" -#include "common/stream/stream.hpp" -#include "exec/exec.hpp" - -// ccl_status_t ccl_set_resize_fn(ccl_resize_fn_t callback) -// { -// CCL_CHECK_IS_BLOCKED(); -// try -// { -// return ccl::global_data::get().executor->create_listener(callback); -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t ccl_init() -// { -// try -// { -// ccl::global_data::get().init(); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t ccl_finalize() -// { -// try -// { -// ccl::global_data::get().reset(); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_get_version(ccl::library_version* version) -// { -// if (!version) -// { -// return ccl_status_invalid_arguments; -// } - -// version->major = CCL_MAJOR_VERSION; -// version->minor = CCL_MINOR_VERSION; -// version->update = CCL_UPDATE_VERSION; -// version->product_status = CCL_PRODUCT_STATUS; -// version->build_date = CCL_PRODUCT_BUILD_DATE; -// version->full = CCL_PRODUCT_FULL; - -// return ccl_status_success; -// } - -// ccl_status_t CCL_API ccl_wait(ccl_request_t req) -// { -// CCL_CHECK_IS_BLOCKED(); -// try -// { -// if (!req) -// { -// LOG_ERROR("empty request"); -// return ccl_status_success; -// } - -// auto request = static_cast(req); -// ccl_wait_impl(ccl::global_data::get().executor.get(), request); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_test(ccl_request_t req, int* is_completed) -// { -// CCL_CHECK_IS_BLOCKED(); -// try -// { -// if 
(!req) -// { -// LOG_ERROR("empty request"); -// if (is_completed) -// { -// *is_completed = 1; -// } -// return ccl_status_success; -// } - -// auto request = static_cast(req); -// auto completed = ccl_test_impl(ccl::global_data::get().executor.get(), request); -// *is_completed = static_cast(completed); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t ccl_comm_create(ccl_comm_t* comm, const ccl_comm_attr_t* attr) -// { -// CCL_CHECK_IS_BLOCKED(); -// CCL_ASSERT(comm); -// try -// { -// ccl::global_data& data = ccl::global_data::get(); -// ccl_comm* comm_ptr = nullptr; - -// if (!attr) -// { -// LOG_DEBUG("create communicator as copy of global communicator"); -// comm_ptr = new ccl_comm(data.comm->rank(), -// data.comm->size(), -// data.comm_ids->acquire(), -// ccl::global_data::get().atl); -// } -// else -// { -// LOG_DEBUG("create communicator with coll_attr"); -// comm_ptr = ccl_comm::create_with_color(attr->color, -// data.comm_ids.get(), -// data.comm.get()); -// } - -// *comm = static_cast(comm_ptr); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t ccl_comm_free(ccl_comm_t comm) -// { -// CCL_CHECK_IS_BLOCKED(); -// CCL_ASSERT(comm); -// LOG_DEBUG("free communicator ", comm); -// try -// { -// delete static_cast(comm); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_get_comm_rank(ccl_comm_t comm, size_t* rank) -// { -// CCL_CHECK_IS_BLOCKED(); -// if (!rank) -// return ccl_status_invalid_arguments; - -// try -// { -// auto comm_ptr = (comm) ? static_cast(comm) : ccl::global_data::get().comm.get(); -// *rank = comm_ptr->rank(); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_get_comm_size(ccl_comm_t comm, size_t* size) -// { -// CCL_CHECK_IS_BLOCKED(); -// if (!size) -// return ccl_status_invalid_arguments; - -// try -// { -// auto comm_ptr = (comm) ? 
static_cast(comm) : ccl::global_data::get().comm.get(); -// *size = comm_ptr->size(); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t ccl_datatype_create(ccl_datatype_t* dtype, const ccl_datatype_attr_t* attr) -// { -// CCL_CHECK_IS_BLOCKED(); -// CCL_ASSERT(dtype); -// LOG_DEBUG("create datatype"); -// try -// { -// *dtype = ccl::global_data::get().dtypes->create(attr); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_get_datatype_size(ccl_datatype_t dtype, size_t* size) -// { -// CCL_CHECK_IS_BLOCKED(); -// if (!size) -// return ccl_status_invalid_arguments; - -// try -// { -// *size = ccl::global_data::get().dtypes->get(dtype).size(); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_datatype_free(ccl_datatype_t dtype) -// { -// CCL_CHECK_IS_BLOCKED(); -// LOG_DEBUG("free datatype ", dtype); -// try -// { -// ccl::global_data::get().dtypes->free(dtype); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t ccl_stream_create(ccl_stream_type_t type, -// void* native_stream, -// ccl_stream_t* stream) -// { -// CCL_CHECK_IS_BLOCKED(); -// CCL_ASSERT(stream); -// try -// { -// //TODO -// #if 0 -// LOG_DEBUG("create stream by type: ", type); -// #ifdef MULTI_GPU_SUPPORT -// #ifdef CCL_ENABLE_SYCL -// *stream = static_cast(stream_provider_dispatcher::create(*static_cast(native_stream)).release()); -// #else -// *stream = static_cast(stream_provider_dispatcher::create(*static_cast(native_stream)).release()); -// #endif -// #else -// #ifdef CCL_ENABLE_SYCL -// if( type != ccl_stream_host) -// { -// *stream = static_cast(stream_provider_dispatcher::create(*static_cast(native_stream)).release()); -// } -// else -// #endif -// { -// *stream = static_cast(stream_provider_dispatcher::create(native_stream).release()); -// } - -// //for legacy stream: override type for 'host' related queue -// 
static_cast(*stream)->type = type; -// #endif -// #endif -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t ccl_stream_free(ccl_stream_t stream) -// { -// CCL_CHECK_IS_BLOCKED(); -// CCL_ASSERT(stream); -// LOG_DEBUG("free stream ", stream); -// try -// { -// delete static_cast(stream); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_allgatherv( -// const void* send_buf, -// size_t send_count, -// void* recv_buf, -// const size_t* recv_counts, -// ccl_datatype_t dtype, -// const ccl_coll_attr_t* attr, -// ccl_comm_t comm, -// ccl_stream_t stream, -// ccl_request_t* req) -// { -// CCL_CHECK_IS_BLOCKED(); -// try -// { -// if (!req) -// { -// return ccl_status_invalid_arguments; -// } -// auto request = ccl_allgatherv_impl(send_buf, send_count, recv_buf, recv_counts, dtype, attr, -// (comm) ? static_cast(comm) : ccl::global_data::get().comm.get(), -// static_cast(stream)); -// *req = static_cast(request); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_allreduce( -// const void* send_buf, -// void* recv_buf, -// size_t count, -// ccl_datatype_t dtype, -// ccl_reduction_t reduction, -// const ccl_coll_attr_t* attr, -// ccl_comm_t comm, -// ccl_stream_t stream, -// ccl_request_t* req) -// { -// CCL_CHECK_IS_BLOCKED(); -// try -// { -// if (!req) -// { -// return ccl_status_invalid_arguments; -// } -// auto request = ccl_allreduce_impl(send_buf, recv_buf, count, dtype, static_cast(reduction), attr, -// (comm) ? 
static_cast(comm) : ccl::global_data::get().comm.get(), -// static_cast(stream)); -// *req = static_cast(request); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_alltoall( -// const void* send_buf, -// void* recv_buf, -// size_t count, -// ccl_datatype_t dtype, -// const ccl_coll_attr_t* attr, -// ccl_comm_t comm, -// ccl_stream_t stream, -// ccl_request_t* req) -// { -// CCL_CHECK_IS_BLOCKED(); -// try -// { -// if (!req) -// { -// return ccl_status_invalid_arguments; -// } -// auto request = ccl_alltoall_impl(send_buf, recv_buf, count, dtype, attr, -// (comm) ? static_cast(comm) : ccl::global_data::get().comm.get(), -// static_cast(stream)); -// *req = static_cast(request); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_alltoallv( -// const void* send_buf, -// const size_t* send_counts, -// void* recv_buf, -// const size_t* recv_counts, -// ccl_datatype_t dtype, -// const ccl_coll_attr_t* attr, -// ccl_comm_t comm, -// ccl_stream_t stream, -// ccl_request_t* req) -// { -// CCL_CHECK_IS_BLOCKED(); -// try -// { -// if (!req) -// { -// return ccl_status_invalid_arguments; -// } -// auto request = ccl_alltoallv_impl(send_buf, send_counts, recv_buf, recv_counts, dtype, attr, -// (comm) ? static_cast(comm) : ccl::global_data::get().comm.get(), -// static_cast(stream)); -// *req = static_cast(request); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_barrier(ccl_comm_t comm, ccl_stream_t stream) -// { -// try -// { -// ccl_barrier_impl((comm) ? 
static_cast(comm) : ccl::global_data::get().comm.get(), -// static_cast(stream)); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_bcast( -// void* buf, -// size_t count, -// ccl_datatype_t dtype, -// size_t root, -// const ccl_coll_attr_t* attr, -// ccl_comm_t comm, -// ccl_stream_t stream, -// ccl_request_t* req) -// { -// CCL_CHECK_IS_BLOCKED(); -// try -// { -// if (!req) -// { -// return ccl_status_invalid_arguments; -// } -// auto request = ccl_broadcast_impl(buf, count, dtype, root, attr, -// (comm) ? static_cast(comm) : ccl::global_data::get().comm.get(), -// static_cast(stream)); -// *req = static_cast(request); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_reduce( -// const void* send_buf, -// void* recv_buf, -// size_t count, -// ccl_datatype_t dtype, -// ccl_reduction_t reduction, -// size_t root, -// const ccl_coll_attr_t* attr, -// ccl_comm_t comm, -// ccl_stream_t stream, -// ccl_request_t* req) -// { -// CCL_CHECK_IS_BLOCKED(); -// try -// { -// if (!req) -// { -// return ccl_status_invalid_arguments; -// } -// auto request = ccl_reduce_impl(send_buf, recv_buf, count, dtype, static_cast(reduction), root, attr, -// (comm) ? 
static_cast(comm) : ccl::global_data::get().comm.get(), -// static_cast(stream)); -// *req = static_cast(request); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } - -// ccl_status_t CCL_API ccl_sparse_allreduce(const void* send_ind_buf, size_t send_ind_count, -// const void* send_val_buf, size_t send_val_count, -// void* recv_ind_buf, size_t recv_ind_count, -// void* recv_val_buf, size_t recv_val_count, -// ccl_datatype_t index_dtype, -// ccl_datatype_t value_dtype, -// ccl_reduction_t reduction, -// const ccl_coll_attr_t* attr, -// ccl_comm_t comm, -// ccl_stream_t stream, -// ccl_request_t* req) -// { -// CCL_CHECK_IS_BLOCKED(); -// try -// { -// if (!req) -// { -// return ccl_status_invalid_arguments; -// } -// auto request = ccl_sparse_allreduce_impl(send_ind_buf, send_ind_count, -// send_val_buf, send_val_count, -// recv_ind_buf, recv_ind_count, -// recv_val_buf, recv_val_count, -// index_dtype, value_dtype, -// static_cast(reduction), attr, -// (comm) ? static_cast(comm) : ccl::global_data::get().comm.get(), -// static_cast(stream)); -// *req = static_cast(request); -// return ccl_status_success; -// } -// COMMON_CATCH_BLOCK(); -// } diff --git a/src/ccl_api_functions.cpp b/src/ccl_api_functions.cpp index 718084907..16e97688f 100644 --- a/src/ccl_api_functions.cpp +++ b/src/ccl_api_functions.cpp @@ -13,1243 +13,1316 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "oneapi/ccl/ccl_environment.hpp" -#include "oneapi/ccl/ccl_api_functions.hpp" -#include "common/comm/host_communicator/host_communicator.hpp" - -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) -#include "common/comm/comm_interface.hpp" -#endif //#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) - -#include "ccl_api_functions_generators.hpp" -#include "common/global/global.hpp" -#include "ccl_gpu_module.hpp" - -namespace ccl { - -/** - * A structure that is a friend of the passed object - * and which allows access to the internal representation of this object - */ -struct impl_dispatch { - template - const typename Object::impl_value_t& operator()(const Object& obj) { - return obj.get_impl(); - } -}; - -#ifdef MULTI_GPU_SUPPORT -/* register a gpu module */ -void register_gpu_module(std::string kernel_dir_path) -{ - // allgatherv - if (!kernel_dir_path.empty()) - { - if(*kernel_dir_path.rbegin() != '/') - { - kernel_dir_path += '/'; - } - } - LOG_INFO("SPV Kernels found directory: ", kernel_dir_path); - std::string kernel_path = kernel_dir_path + "ring_allgatherv.spv"; - register_gpu_module_source(kernel_path.c_str(), - ccl::device_topology_type::ring, - ccl_coll_allgatherv); - // register__gpu_module_source("kernels/a2a_allgatherv.spv", - // ccl::device_topology_type::a2a, - // ccl_coll_allgatherv); - // alltoallv - kernel_path = kernel_dir_path + "ring_alltoallv.spv"; - register_gpu_module_source(kernel_path.c_str(), - ccl::device_topology_type::ring, - ccl_coll_alltoallv); - // register_gpu_module_source("kernels/a2a_alltoallv.spv", - // ccl::device_topology_type::a2a, - // ccl_coll_alltoallv); - // allreduce - kernel_path = kernel_dir_path + "ring_allreduce.spv"; - register_gpu_module_source(kernel_path.c_str(), - ccl::device_topology_type::ring, - ccl_coll_allreduce); - kernel_path = kernel_dir_path + "a2a_allreduce.spv"; - register_gpu_module_source(kernel_path.c_str(), - ccl::device_topology_type::a2a, - ccl_coll_allreduce); - 
// bcast - kernel_path = kernel_dir_path + "ring_bcast.spv"; - register_gpu_module_source(kernel_path.c_str(), - ccl::device_topology_type::ring, - ccl_coll_bcast); - kernel_path = kernel_dir_path + "a2a_bcast.spv"; - register_gpu_module_source(kernel_path.c_str(), - ccl::device_topology_type::a2a, - ccl_coll_bcast); - // reduce - kernel_path = kernel_dir_path + "ring_reduce.spv"; - register_gpu_module_source(kernel_path.c_str(), - ccl::device_topology_type::ring, - ccl_coll_reduce); - // register_gpu_module_source("kernels/a2a_reduce.spv", - // ccl_topology_class_t::a2a_algo_class, - // ccl_coll_reduce); -} -#endif //MULTI_GPU_SUPPORT - -void CCL_API init() { - auto& env = environment::instance(); - (void)env; -#ifdef MULTI_GPU_SUPPORT - const auto& env_object = ccl::global_data::env(); - - //WA - if (!env_object.kernel_path.empty()) - { - register_gpu_module(env_object.kernel_path); - } -#endif //MULTI_GPU_SUPPORT -} - -/******************** ENVIRONMENT ********************/ - -library_version CCL_API get_library_version() { - return environment::instance().get_library_version(); -} - -/* datatype */ -datatype CCL_API register_datatype(const datatype_attr& attr) { - return environment::instance().register_datatype(attr); -} - -void CCL_API deregister_datatype(datatype dtype) { - return environment::instance().deregister_datatype(dtype); -} - -size_t CCL_API get_datatype_size(datatype dtype) { - return environment::instance().get_datatype_size(dtype); -} - -/* KVS */ -shared_ptr_class CCL_API create_main_kvs() { - return environment::instance().create_main_kvs(); -} - -shared_ptr_class CCL_API create_kvs(const kvs::address_type& addr) { - return environment::instance().create_kvs(addr); -} - -/* device */ -device CCL_API create_device() -{ - static empty_t empty {}; - return environment::instance().create_device(empty); -} - -/* context */ -context CCL_API create_context() -{ - static empty_t empty {}; - return environment::instance().create_context(empty); -} - 
-/* stream */ -stream CCL_API create_stream() -{ - return default_stream; -} - -#ifdef CCL_ENABLE_SYCL -communicator create_single_device_communicator(const size_t comm_size, - const size_t rank, - const cl::sycl::device& device, - const cl::sycl::context& context, - shared_ptr_class kvs) { - return environment::instance().create_single_device_communicator( - comm_size, rank, device, context, kvs); -} -#endif // CCL_ENABLE_SYCL - -// communicator create_single_device_communicator(const size_t world_size, -// const size_t rank, -// cl::sycl::queue queue, -// shared_ptr_class kvs) const; - -// template -// communicator create_single_device_communicator(const size_t world_size, -// const size_t rank, -// const DeviceSelectorType& selector, -// shared_ptr_class kvs) const -// { -// return return environment::instance().create_single_device_communicator(world_size, rank, cl::sycl::device(selector), kvs); -// } - -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) - -vector_class split_device_communicators( - const vector_class>& attrs) { - // TODO not implemented - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - - // return environment::instance().split_device_communicators(attrs); - return {}; -} - -#endif //#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) - -namespace preview { - -/* communicator */ -communicator CCL_API create_communicator() { - return environment::instance().create_communicator(); -} - -communicator CCL_API create_communicator(const size_t size, shared_ptr_class kvs) { - return environment::instance().create_communicator(size, kvs); -} - -} // namespace preview - -communicator CCL_API create_communicator(const size_t size, - const size_t rank, - shared_ptr_class kvs) { - return environment::instance().create_communicator(size, rank, kvs); -} - -/******************** COMMUNICATOR ********************/ - -/* allgatherv */ -CCL_API event allgatherv(const void* send_buf, - size_t send_count, - void* 
recv_buf, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const stream& op_stream, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_buf, recv_counts, dtype, disp(op_stream), attr, deps); -} - -CCL_API event allgatherv(const void* send_buf, - size_t send_count, - void* recv_buf, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_buf, recv_counts, dtype, disp(default_stream), attr, deps); -} - -CCL_API event allgatherv(const void* send_buf, - size_t send_count, - const vector_class& recv_bufs, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const stream& op_stream, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_bufs, recv_counts, dtype, disp(op_stream), attr, deps); -} - -CCL_API event allgatherv(const void* send_buf, - size_t send_count, - const vector_class& recv_bufs, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_bufs, recv_counts, dtype, disp(default_stream), attr, deps); -} - -template -event allgatherv(const BufferType* send_buf, - size_t send_count, - BufferType* recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const stream& op_stream, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_buf, recv_counts, disp(op_stream), attr, deps); -} - -template -event allgatherv(const BufferType* send_buf, - size_t send_count, - BufferType* 
recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_buf, recv_counts, disp(default_stream), attr, deps); -} - -template -event allgatherv(const BufferType* send_buf, - size_t send_count, - vector_class& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const stream& op_stream, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_bufs, recv_counts, disp(op_stream), attr, deps); -} - -template -event allgatherv(const BufferType* send_buf, - size_t send_count, - vector_class& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_bufs, recv_counts, disp(default_stream), attr, deps); -} - -template -event allgatherv(const BufferObjectType& send_buf, - size_t send_count, - BufferObjectType& recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const stream& op_stream, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_buf, recv_counts, disp(op_stream), attr, deps); -} - -template -event allgatherv(const BufferObjectType& send_buf, - size_t send_count, - BufferObjectType& recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_buf, recv_counts, disp(default_stream), attr, deps); -} - -template -event allgatherv(const BufferObjectType& send_buf, - size_t send_count, - vector_class>& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - 
const stream& op_stream, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_bufs, recv_counts, disp(op_stream), attr, deps); -} - -template -event allgatherv(const BufferObjectType& send_buf, - size_t send_count, - vector_class>& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const allgatherv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allgatherv( - send_buf, send_count, recv_bufs, recv_counts, disp(default_stream), attr, deps); -} - -/* allreduce */ -CCL_API event allreduce(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - reduction reduction, - const communicator& comm, - const stream& op_stream, - const allreduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allreduce( - send_buf, recv_buf, count, dtype, reduction, disp(op_stream), attr, deps); -} - -CCL_API event allreduce(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - reduction reduction, - const communicator& comm, - const allreduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allreduce( - send_buf, recv_buf, count, dtype, reduction, disp(default_stream), attr, deps); -} - -template -event allreduce(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - reduction reduction, - const communicator& comm, - const stream& op_stream, - const allreduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allreduce(send_buf, recv_buf, count, reduction, disp(op_stream), attr, deps); -} - -template -event allreduce(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - reduction reduction, - const communicator& comm, - const allreduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allreduce(send_buf, recv_buf, count, reduction, 
disp(default_stream), attr, deps); -} - -template -event allreduce(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - reduction reduction, - const communicator& comm, - const stream& op_stream, - const allreduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allreduce(send_buf, recv_buf, count, reduction, disp(op_stream), attr, deps); -} - -template -event allreduce(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - reduction reduction, - const communicator& comm, - const allreduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->allreduce(send_buf, recv_buf, count, reduction, disp(default_stream), attr, deps); -} - -/* alltoall */ -CCL_API event alltoall(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - const communicator& comm, - const stream& op_stream, - const alltoall_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoall(send_buf, recv_buf, count, dtype, disp(op_stream), attr, deps); -} - -CCL_API event alltoall(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - const communicator& comm, - const alltoall_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoall(send_buf, recv_buf, count, dtype, disp(default_stream), attr, deps); -} - -CCL_API event alltoall(const vector_class& send_buf, - const vector_class& recv_buf, - size_t count, - datatype dtype, - const communicator& comm, - const stream& op_stream, - const alltoall_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoall(send_buf, recv_buf, count, dtype, disp(op_stream), attr, deps); -} - -template -event alltoall(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - const communicator& comm, - const stream& op_stream, - const alltoall_attr& attr, - const vector_class& deps) { - impl_dispatch 
disp; - return disp(comm)->alltoall(send_buf, recv_buf, count, disp(op_stream), attr, deps); -} - -template -event alltoall(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - const communicator& comm, - const alltoall_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoall(send_buf, recv_buf, count, disp(default_stream), attr, deps); -} - -template -event alltoall(const vector_class& send_buf, - const vector_class& recv_buf, - size_t count, - const communicator& comm, - const stream& op_stream, - const alltoall_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoall(send_buf, recv_buf, count, disp(op_stream), attr, deps); -} - -template -event alltoall(const vector_class& send_buf, - const vector_class& recv_buf, - size_t count, - const communicator& comm, - const alltoall_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoall(send_buf, recv_buf, count, disp(default_stream), attr, deps); -} - -template -event alltoall(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - const communicator& comm, - const stream& op_stream, - const alltoall_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoall(send_buf, recv_buf, count, disp(op_stream), attr, deps); -} - -template -event alltoall(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - const communicator& comm, - const alltoall_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoall(send_buf, recv_buf, count, disp(default_stream), attr, deps); -} - -template -event alltoall(const vector_class>& send_buf, - const vector_class>& recv_buf, - size_t count, - const communicator& comm, - const stream& op_stream, - const alltoall_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoall(send_buf, recv_buf, count, 
disp(op_stream), attr, deps); -} - -template -event alltoall(const vector_class>& send_buf, - const vector_class>& recv_buf, - size_t count, - const communicator& comm, - const alltoall_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoall(send_buf, recv_buf, count, disp(default_stream), attr, deps); -} - -/* alltoallv */ -CCL_API event alltoallv(const void* send_buf, - const vector_class& send_counts, - void* recv_buf, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const stream& op_stream, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_buf, send_counts, recv_buf, recv_counts, dtype, disp(op_stream), attr, deps); -} - -CCL_API event alltoallv(const void* send_buf, - const vector_class& send_counts, - void* recv_buf, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_buf, send_counts, recv_buf, recv_counts, dtype, disp(default_stream), attr, deps); -} - -CCL_API event alltoallv(const vector_class& send_bufs, - const vector_class& send_counts, - const vector_class& recv_bufs, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const stream& op_stream, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_bufs, send_counts, recv_bufs, recv_counts, dtype, disp(op_stream), attr, deps); -} - -CCL_API event alltoallv(const vector_class& send_bufs, - const vector_class& send_counts, - const vector_class& recv_bufs, - const vector_class& recv_counts, - datatype dtype, - const communicator& comm, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_bufs, send_counts, recv_bufs, recv_counts, dtype, 
disp(default_stream), attr, deps); -} - -template -event alltoallv(const BufferType* send_buf, - const vector_class& send_counts, - BufferType* recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const stream& op_stream, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_buf, send_counts, recv_buf, recv_counts, disp(op_stream), attr, deps); -} - -template -event alltoallv(const BufferType* send_buf, - const vector_class& send_counts, - BufferType* recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_buf, send_counts, recv_buf, recv_counts, disp(default_stream), attr, deps); -} - -template -event alltoallv(const vector_class& send_bufs, - const vector_class& send_counts, - const vector_class& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const stream& op_stream, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_bufs, send_counts, recv_bufs, recv_counts, disp(op_stream), attr, deps); -} - -template -event alltoallv(const vector_class& send_bufs, - const vector_class& send_counts, - const vector_class& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_bufs, send_counts, recv_bufs, recv_counts, disp(default_stream), attr, deps); -} - -template -event alltoallv(const BufferObjectType& send_buf, - const vector_class& send_counts, - BufferObjectType& recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const stream& op_stream, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_buf, send_counts, recv_buf, 
recv_counts, disp(op_stream), attr, deps); -} - -template -event alltoallv(const BufferObjectType& send_buf, - const vector_class& send_counts, - BufferObjectType& recv_buf, - const vector_class& recv_counts, - const communicator& comm, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_buf, send_counts, recv_buf, recv_counts, disp(default_stream), attr, deps); -} - -template -event alltoallv(const vector_class>& send_bufs, - const vector_class& send_counts, - const vector_class>& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const stream& op_stream, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_bufs, send_counts, recv_bufs, recv_counts, disp(op_stream), attr, deps); -} - -template -event alltoallv(const vector_class>& send_bufs, - const vector_class& send_counts, - const vector_class>& recv_bufs, - const vector_class& recv_counts, - const communicator& comm, - const alltoallv_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->alltoallv( - send_bufs, send_counts, recv_bufs, recv_counts, disp(default_stream), attr, deps); -} - -/* barrier */ -CCL_API event barrier(const communicator& comm, - const stream& op_stream, - const barrier_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->barrier(disp(op_stream), attr, deps); -} - -CCL_API event barrier(const communicator& comm, - const barrier_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->barrier(disp(default_stream), attr, deps); -} - -/* broadcast */ -CCL_API event broadcast(void* buf, - size_t count, - datatype dtype, - size_t root, - const communicator& comm, - const stream& op_stream, - const broadcast_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->bcast(buf, count, dtype, root, disp(op_stream), attr, 
deps); -} - -CCL_API event broadcast(void* buf, - size_t count, - datatype dtype, - size_t root, - const communicator& comm, - const broadcast_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->bcast(buf, count, dtype, root, disp(default_stream), attr, deps); -} - -template -event broadcast(BufferType* buf, - size_t count, - size_t root, - const communicator& comm, - const stream& op_stream, - const broadcast_attr& attr, - const vector_class& deps) - -{ - impl_dispatch disp; - return disp(comm)->bcast(buf, count, root, disp(op_stream), attr, deps); -} - -template -event broadcast(BufferType* buf, - size_t count, - size_t root, - const communicator& comm, - const broadcast_attr& attr, - const vector_class& deps) - -{ - impl_dispatch disp; - return disp(comm)->bcast(buf, count, root, disp(default_stream), attr, deps); -} - -template -event broadcast(BufferObjectType& buf, - size_t count, - size_t root, - const communicator& comm, - const stream& op_stream, - const broadcast_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->bcast(buf, count, root, disp(op_stream), attr, deps); -} - -template -event broadcast(BufferObjectType& buf, - size_t count, - size_t root, - const communicator& comm, - const broadcast_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->bcast(buf, count, root, disp(default_stream), attr, deps); -} - -/* reduce */ -CCL_API event reduce(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - reduction reduction, - size_t root, - const communicator& comm, - const stream& op_stream, - const reduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->reduce( - send_buf, recv_buf, count, dtype, reduction, root, disp(op_stream), attr, deps); -} - -CCL_API event reduce(const void* send_buf, - void* recv_buf, - size_t count, - datatype dtype, - reduction reduction, - size_t root, - const communicator& 
comm, - const reduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->reduce( - send_buf, recv_buf, count, dtype, reduction, root, disp(default_stream), attr, deps); -} - -template -event reduce(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - reduction reduction, - size_t root, - const communicator& comm, - const stream& op_stream, - const reduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->reduce( - send_buf, recv_buf, count, reduction, root, disp(op_stream), attr, deps); -} - -template -event reduce(const BufferType* send_buf, - BufferType* recv_buf, - size_t count, - reduction reduction, - size_t root, - const communicator& comm, - const reduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->reduce( - send_buf, recv_buf, count, reduction, root, disp(default_stream), attr, deps); -} - -template -event reduce(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - reduction reduction, - size_t root, - const communicator& comm, - const stream& op_stream, - const reduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->reduce( - send_buf, recv_buf, count, reduction, root, disp(op_stream), attr, deps); -} - -template -event reduce(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t count, - reduction reduction, - size_t root, - const communicator& comm, - const reduce_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->reduce( - send_buf, recv_buf, count, reduction, root, disp(default_stream), attr, deps); -} - -/* reduce_scatter */ -CCL_API event reduce_scatter(const void* send_buf, - void* recv_buf, - size_t recv_count, - datatype dtype, - reduction reduction, - const communicator& comm, - const stream& op_stream, - const reduce_scatter_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return 
disp(comm)->reduce_scatter( - send_buf, recv_buf, recv_count, dtype, reduction, disp(op_stream), attr, deps); -} - -CCL_API event reduce_scatter(const void* send_buf, - void* recv_buf, - size_t recv_count, - datatype dtype, - reduction reduction, - const communicator& comm, - const reduce_scatter_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->reduce_scatter( - send_buf, recv_buf, recv_count, dtype, reduction, disp(default_stream), attr, deps); -} - -template -event reduce_scatter(const BufferType* send_buf, - BufferType* recv_buf, - size_t recv_count, - reduction reduction, - const communicator& comm, - const stream& op_stream, - const reduce_scatter_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->reduce_scatter( - send_buf, recv_buf, recv_count, reduction, disp(op_stream), attr, deps); -} - -template -event reduce_scatter(const BufferType* send_buf, - BufferType* recv_buf, - size_t recv_count, - reduction reduction, - const communicator& comm, - const reduce_scatter_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->reduce_scatter( - send_buf, recv_buf, recv_count, reduction, disp(default_stream), attr, deps); -} - -template -event reduce_scatter(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t recv_count, - reduction reduction, - const communicator& comm, - const stream& op_stream, - const reduce_scatter_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->reduce_scatter( - send_buf, recv_buf, recv_count, reduction, disp(op_stream), attr, deps); -} - -template -event reduce_scatter(const BufferObjectType& send_buf, - BufferObjectType& recv_buf, - size_t recv_count, - reduction reduction, - const communicator& comm, - const reduce_scatter_attr& attr, - const vector_class& deps) { - impl_dispatch disp; - return disp(comm)->reduce_scatter( - send_buf, recv_buf, recv_count, reduction, 
disp(default_stream), attr, deps); -} - -namespace preview { - -/* sparse_allreduce */ -CCL_API ccl::event sparse_allreduce(const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - ccl::datatype index_dtype, - ccl::datatype value_dtype, - ccl::reduction reduction, - const ccl::communicator& comm, - const ccl::stream& op_stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - ccl::impl_dispatch disp; - return disp(comm)->sparse_allreduce(send_ind_buf, - send_ind_count, - send_val_buf, - send_val_count, - recv_ind_buf, - recv_ind_count, - recv_val_buf, - recv_val_count, - index_dtype, - value_dtype, - reduction, - disp(op_stream), - attr, - deps); -} - -CCL_API ccl::event sparse_allreduce(const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - ccl::datatype index_dtype, - ccl::datatype value_dtype, - ccl::reduction reduction, - const ccl::communicator& comm, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - ccl::impl_dispatch disp; - return disp(comm)->sparse_allreduce(send_ind_buf, - send_ind_count, - send_val_buf, - send_val_count, - recv_ind_buf, - recv_ind_count, - recv_val_buf, - recv_val_count, - index_dtype, - value_dtype, - reduction, - disp(default_stream), - attr, - deps); -} - -template -ccl::event sparse_allreduce(const IndexBufferType* send_ind_buf, - size_t send_ind_count, - const ValueBufferType* send_val_buf, - size_t send_val_count, - IndexBufferType* recv_ind_buf, - size_t recv_ind_count, - ValueBufferType* recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::communicator& comm, - const ccl::stream& op_stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& 
deps) { - ccl::impl_dispatch disp; - return disp(comm)->sparse_allreduce(send_ind_buf, - send_ind_count, - send_val_buf, - send_val_count, - recv_ind_buf, - recv_ind_count, - recv_val_buf, - recv_val_count, - reduction, - disp(op_stream), - attr, - deps); -} - -template -ccl::event sparse_allreduce(const IndexBufferType* send_ind_buf, - size_t send_ind_count, - const ValueBufferType* send_val_buf, - size_t send_val_count, - IndexBufferType* recv_ind_buf, - size_t recv_ind_count, - ValueBufferType* recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::communicator& comm, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - ccl::impl_dispatch disp; - return disp(comm)->sparse_allreduce(send_ind_buf, - send_ind_count, - send_val_buf, - send_val_count, - recv_ind_buf, - recv_ind_count, - recv_val_buf, - recv_val_count, - reduction, - disp(default_stream), - attr, - deps); -} - -// template -// ccl::event -// sparse_allreduce(const IndexBufferObjectType& send_ind_buf, -// size_t send_ind_count, -// const ValueBufferObjectType& send_val_buf, -// size_t send_val_count, -// IndexBufferObjectType& recv_ind_buf, -// size_t recv_ind_count, -// ValueBufferObjectType& recv_val_buf, -// size_t recv_val_count, -// ccl::reduction reduction, -// const ccl::communicator& comm, -// const ccl::stream& op_stream, -// const ccl::sparse_allreduce_attr& attr, -// const ccl::vector_class& deps) -// { -// ccl::impl_dispatch disp; -// return disp(comm)->sparse_allreduce(send_ind_buf, send_ind_count, -// send_val_buf, send_val_count, -// recv_ind_buf, recv_ind_count, -// recv_val_buf, recv_val_count, -// reduction, -// disp(op_stream), attr, deps); -// } -// -// template -// ccl::event -// sparse_allreduce(const IndexBufferObjectType& send_ind_buf, -// size_t send_ind_count, -// const ValueBufferObjectType& send_val_buf, -// size_t send_val_count, -// IndexBufferObjectType& recv_ind_buf, -// size_t recv_ind_count, -// ValueBufferObjectType& 
recv_val_buf, -// size_t recv_val_count, -// ccl::reduction reduction, -// const ccl::communicator& comm, -// const ccl::sparse_allreduce_attr& attr, -// const ccl::vector_class& deps) -// { -// ccl::impl_dispatch disp; -// return disp(comm)->sparse_allreduce(send_ind_buf, send_ind_count, -// send_val_buf, send_val_count, -// recv_ind_buf, recv_ind_count, -// recv_val_buf, recv_val_count, -// reduction, -// disp(default_stream), attr, deps); -// } - -} // namespace preview - -// API force instantiations for Operations -API_DEVICE_COMM_OP_PTR_EXPLICIT_INSTANTIATION(char); -API_DEVICE_COMM_OP_PTR_EXPLICIT_INSTANTIATION(int); -API_DEVICE_COMM_OP_PTR_EXPLICIT_INSTANTIATION(int64_t); -API_DEVICE_COMM_OP_PTR_EXPLICIT_INSTANTIATION(uint64_t); -API_DEVICE_COMM_OP_PTR_EXPLICIT_INSTANTIATION(float); -API_DEVICE_COMM_OP_PTR_EXPLICIT_INSTANTIATION(double); - -#ifdef CCL_ENABLE_SYCL -#ifndef COMMA -#define COMMA , -#endif -API_DEVICE_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); -API_DEVICE_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); -API_DEVICE_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); -API_DEVICE_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); -API_DEVICE_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); -API_DEVICE_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); -#undef COMMA -#endif // CCL_ENABLE_SYCL - -namespace preview { - -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(char, char); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(char, int); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(char, ccl::bf16); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(char, float); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(char, double); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(char, int64_t); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(char, uint64_t); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int, char); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int, int); 
-API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int, ccl::bf16); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int, float); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int, double); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int, int64_t); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int, uint64_t); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int64_t, char); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int64_t, int); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int64_t, ccl::bf16); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int64_t, float); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int64_t, double); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int64_t, int64_t); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int64_t, uint64_t); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(uint64_t, char); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(uint64_t, int); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(uint64_t, ccl::bf16); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(uint64_t, float); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(uint64_t, double); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(uint64_t, int64_t); -API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(uint64_t, uint64_t); - -// #ifdef CCL_ENABLE_SYCL -// #ifndef COMMA -// #define COMMA , -// #endif -// API_DEVICE_COMM_SPARSE_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer, -// cl::sycl::buffer); -// API_DEVICE_COMM_SPARSE_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer, -// cl::sycl::buffer); - -// API_DEVICE_COMM_SPARSE_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer, -// cl::sycl::buffer); -// API_DEVICE_COMM_SPARSE_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer, -// cl::sycl::buffer); -// #undef COMMA -// #endif //CCL_ENABLE_SYCL - -} // namespace preview - -} // namespace ccl +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/environment.hpp" 
+#include "oneapi/ccl/api_functions.hpp" +#include "common/comm/host_communicator/host_communicator.hpp" +#include "oneapi/ccl/exception.hpp" + +#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) +#include "common/comm/comm_interface.hpp" +#endif //#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) + +#include "ccl_api_functions_generators.hpp" +#include "common/global/global.hpp" +#include "ccl_gpu_module.hpp" + +namespace ccl { + +namespace v1 { + +/** + * A structure that is a friend of the passed object + * and which allows access to the internal representation of this object + */ +struct impl_dispatch { + template + const typename Object::impl_value_t& operator()(const Object& obj) { + return obj.get_impl(); + } +}; + +#ifdef MULTI_GPU_SUPPORT +/* register a gpu module */ +void register_gpu_module(std::string kernel_dir_path) { + if (!kernel_dir_path.empty()) { + if (*kernel_dir_path.rbegin() != '/') { + kernel_dir_path += '/'; + } + } + LOG_INFO("SPV Kernels found directory: ", kernel_dir_path); + + /* + * TODO: + * Important: Fix kernels data types generations, then uncoment + * the registration module. 
+ */ + + // allgatherv + std::string kernel_path = kernel_dir_path + "ring_allgatherv.spv"; + register_gpu_module_source( + kernel_path.c_str(), ccl::device_topology_type::ring, ccl_coll_allgatherv); + // kernel_path = kernel_dir_path + "a2a_allgatherv.spv"; + // register_gpu_module_source(kernel_path.c_str(), + // ccl::device_topology_type::a2a, + // ccl_coll_allgatherv); + // alltoallv + kernel_path = kernel_dir_path + "ring_alltoallv.spv"; + register_gpu_module_source( + kernel_path.c_str(), ccl::device_topology_type::ring, ccl_coll_alltoallv); + // register_gpu_module_source("kernels/a2a_alltoallv.spv", + // ccl::device_topology_type::a2a, + // ccl_coll_alltoallv); + // allreduce + kernel_path = kernel_dir_path + "ring_allreduce.spv"; + register_gpu_module_source( + kernel_path.c_str(), ccl::device_topology_type::ring, ccl_coll_allreduce); + // kernel_path = kernel_dir_path + "a2a_allreduce.spv"; + // register_gpu_module_source(kernel_path.c_str(), + // ccl::device_topology_type::a2a, + // ccl_coll_allreduce); + // // bcast + kernel_path = kernel_dir_path + "ring_bcast.spv"; + register_gpu_module_source( + kernel_path.c_str(), ccl::device_topology_type::ring, ccl_coll_bcast); + // kernel_path = kernel_dir_path + "a2a_bcast.spv"; + // register_gpu_module_source(kernel_path.c_str(), + // ccl::device_topology_type::a2a, + // ccl_coll_bcast); + kernel_path = kernel_dir_path + "ring_reduce.spv"; + register_gpu_module_source( + kernel_path.c_str(), ccl::device_topology_type::ring, ccl_coll_reduce); + // kernel_path = kernel_dir_path + "a2a_reduce.spv"; + // register_gpu_module_source(kernel_path.c_str(), + // ccl::device_topology_type::a2a, + // ccl_coll_reduce); +} +#endif //MULTI_GPU_SUPPORT + +void CCL_API init(const init_attr& attr) { + auto& env = detail::environment::instance(); + (void)env; +#ifdef MULTI_GPU_SUPPORT + const auto& env_object = ccl::global_data::env(); + + //WA + if (!env_object.kernel_path.empty()) { + 
register_gpu_module(env_object.kernel_path); + } +#endif //MULTI_GPU_SUPPORT +} + +/******************** ENVIRONMENT ********************/ + +library_version CCL_API get_library_version() { + return detail::environment::get_library_version(); +} + +/* datatype */ +datatype CCL_API register_datatype(const datatype_attr& attr) { + return detail::environment::instance().register_datatype(attr); +} + +void CCL_API deregister_datatype(datatype dtype) { + return detail::environment::instance().deregister_datatype(dtype); +} + +size_t CCL_API get_datatype_size(datatype dtype) { + return detail::environment::instance().get_datatype_size(dtype); +} + +/* KVS */ +shared_ptr_class CCL_API create_main_kvs(const kvs_attr& attr) { + return detail::environment::instance().create_main_kvs(attr); +} + +shared_ptr_class CCL_API create_kvs(const kvs::address_type& addr, const kvs_attr& attr) { + return detail::environment::instance().create_kvs(addr, attr); +} + +/* device */ +device CCL_API create_device() { + static empty_t empty{}; + return detail::environment::instance().create_device(empty); +} + +/* context */ +context CCL_API create_context() { + static empty_t empty{}; + return detail::environment::instance().create_context(empty); +} + +/* stream */ +stream CCL_API create_stream() { + return default_stream; +} + +#ifdef CCL_ENABLE_SYCL +communicator create_single_device_communicator(const int comm_size, + const int rank, + const cl::sycl::device& device, + const cl::sycl::context& context, + shared_ptr_class kvs) { + return detail::environment::instance().create_single_device_communicator( + comm_size, rank, device, context, kvs); +} +#endif // CCL_ENABLE_SYCL + +// communicator create_single_device_communicator(const size_t world_size, +// const int rank, +// cl::sycl::queue queue, +// shared_ptr_class kvs) const; + +// template +// communicator create_single_device_communicator(const size_t world_size, +// const int rank, +// const DeviceSelectorType& selector, +// 
shared_ptr_class kvs) const +// { +// return return detail::environment::instance().create_single_device_communicator(world_size, rank, cl::sycl::device(selector), kvs); +// } + +} // namespace v1 + +namespace preview { + +vector_class split_communicators( + const vector_class>& attrs) { + // TODO not implemented + throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); + + // return detail::environment::instance().split_device_communicators(attrs); + return {}; +} + +/* communicator */ +communicator CCL_API create_communicator(const comm_attr& attr) { + return ccl::detail::environment::instance().create_communicator(attr); +} + +communicator CCL_API create_communicator(const int size, + shared_ptr_class kvs, + const comm_attr& attr) { + return ccl::detail::environment::instance().create_communicator(size, kvs, attr); +} + +} // namespace preview + +namespace v1 { + +communicator CCL_API create_communicator(const int size, + const int rank, + shared_ptr_class kvs, + const comm_attr& attr) { + return detail::environment::instance().create_communicator(size, rank, kvs, attr); +} + +/******************** COMMUNICATOR ********************/ + +#define CHECK_DEPS(deps) \ + do { \ + if (!deps.empty()) { \ + throw ccl::exception( \ + std::string(__PRETTY_FUNCTION__) + \ + " - handling a vector of events that the operation should depend on is not implemented"); \ + } \ + } while (0) + +/* allgatherv */ +CCL_API event allgatherv(const void* send_buf, + size_t send_count, + void* recv_buf, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const stream& op_stream, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_buf, recv_counts, dtype, disp(op_stream), attr, deps); +} + +CCL_API event allgatherv(const void* send_buf, + size_t send_count, + void* recv_buf, + const vector_class& recv_counts, + datatype 
dtype, + const communicator& comm, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_buf, recv_counts, dtype, disp(default_stream), attr, deps); +} + +CCL_API event allgatherv(const void* send_buf, + size_t send_count, + const vector_class& recv_bufs, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const stream& op_stream, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_bufs, recv_counts, dtype, disp(op_stream), attr, deps); +} + +CCL_API event allgatherv(const void* send_buf, + size_t send_count, + const vector_class& recv_bufs, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_bufs, recv_counts, dtype, disp(default_stream), attr, deps); +} + +template +event allgatherv(const BufferType* send_buf, + size_t send_count, + BufferType* recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const stream& op_stream, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_buf, recv_counts, disp(op_stream), attr, deps); +} + +template +event allgatherv(const BufferType* send_buf, + size_t send_count, + BufferType* recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_buf, recv_counts, disp(default_stream), attr, deps); +} + +template +event allgatherv(const BufferType* send_buf, + size_t send_count, + 
vector_class& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const stream& op_stream, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_bufs, recv_counts, disp(op_stream), attr, deps); +} + +template +event allgatherv(const BufferType* send_buf, + size_t send_count, + vector_class& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_bufs, recv_counts, disp(default_stream), attr, deps); +} + +template +event allgatherv(const BufferObjectType& send_buf, + size_t send_count, + BufferObjectType& recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const stream& op_stream, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_buf, recv_counts, disp(op_stream), attr, deps); +} + +template +event allgatherv(const BufferObjectType& send_buf, + size_t send_count, + BufferObjectType& recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_buf, recv_counts, disp(default_stream), attr, deps); +} + +template +event allgatherv(const BufferObjectType& send_buf, + size_t send_count, + vector_class>& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const stream& op_stream, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_bufs, recv_counts, disp(op_stream), attr, deps); +} + +template +event allgatherv(const 
BufferObjectType& send_buf, + size_t send_count, + vector_class>& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const allgatherv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allgatherv( + send_buf, send_count, recv_bufs, recv_counts, disp(default_stream), attr, deps); +} + +/* allreduce */ +CCL_API event allreduce(const void* send_buf, + void* recv_buf, + size_t count, + datatype dtype, + reduction reduction, + const communicator& comm, + const stream& op_stream, + const allreduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allreduce( + send_buf, recv_buf, count, dtype, reduction, disp(op_stream), attr, deps); +} + +CCL_API event allreduce(const void* send_buf, + void* recv_buf, + size_t count, + datatype dtype, + reduction reduction, + const communicator& comm, + const allreduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allreduce( + send_buf, recv_buf, count, dtype, reduction, disp(default_stream), attr, deps); +} + +template +event allreduce(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + reduction reduction, + const communicator& comm, + const stream& op_stream, + const allreduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allreduce(send_buf, recv_buf, count, reduction, disp(op_stream), attr, deps); +} + +template +event allreduce(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + reduction reduction, + const communicator& comm, + const allreduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allreduce( + send_buf, recv_buf, count, reduction, disp(default_stream), attr, deps); +} + +template +event allreduce(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + 
reduction reduction, + const communicator& comm, + const stream& op_stream, + const allreduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allreduce(send_buf, recv_buf, count, reduction, disp(op_stream), attr, deps); +} + +template +event allreduce(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + reduction reduction, + const communicator& comm, + const allreduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->allreduce( + send_buf, recv_buf, count, reduction, disp(default_stream), attr, deps); +} + +/* alltoall */ +CCL_API event alltoall(const void* send_buf, + void* recv_buf, + size_t count, + datatype dtype, + const communicator& comm, + const stream& op_stream, + const alltoall_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoall(send_buf, recv_buf, count, dtype, disp(op_stream), attr, deps); +} + +CCL_API event alltoall(const void* send_buf, + void* recv_buf, + size_t count, + datatype dtype, + const communicator& comm, + const alltoall_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoall(send_buf, recv_buf, count, dtype, disp(default_stream), attr, deps); +} + +CCL_API event alltoall(const vector_class& send_buf, + const vector_class& recv_buf, + size_t count, + datatype dtype, + const communicator& comm, + const stream& op_stream, + const alltoall_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoall(send_buf, recv_buf, count, dtype, disp(op_stream), attr, deps); +} + +template +event alltoall(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + const communicator& comm, + const stream& op_stream, + const alltoall_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return 
disp(comm)->alltoall(send_buf, recv_buf, count, disp(op_stream), attr, deps); +} + +template +event alltoall(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + const communicator& comm, + const alltoall_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoall(send_buf, recv_buf, count, disp(default_stream), attr, deps); +} + +template +event alltoall(const vector_class& send_buf, + const vector_class& recv_buf, + size_t count, + const communicator& comm, + const stream& op_stream, + const alltoall_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoall(send_buf, recv_buf, count, disp(op_stream), attr, deps); +} + +template +event alltoall(const vector_class& send_buf, + const vector_class& recv_buf, + size_t count, + const communicator& comm, + const alltoall_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoall(send_buf, recv_buf, count, disp(default_stream), attr, deps); +} + +template +event alltoall(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + const communicator& comm, + const stream& op_stream, + const alltoall_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoall(send_buf, recv_buf, count, disp(op_stream), attr, deps); +} + +template +event alltoall(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + const communicator& comm, + const alltoall_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoall(send_buf, recv_buf, count, disp(default_stream), attr, deps); +} + +template +event alltoall(const vector_class>& send_buf, + const vector_class>& recv_buf, + size_t count, + const communicator& comm, + const stream& op_stream, + const alltoall_attr& attr, + const vector_class& deps) { + 
CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoall(send_buf, recv_buf, count, disp(op_stream), attr, deps); +} + +template +event alltoall(const vector_class>& send_buf, + const vector_class>& recv_buf, + size_t count, + const communicator& comm, + const alltoall_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoall(send_buf, recv_buf, count, disp(default_stream), attr, deps); +} + +/* alltoallv */ +CCL_API event alltoallv(const void* send_buf, + const vector_class& send_counts, + void* recv_buf, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const stream& op_stream, + const alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_buf, send_counts, recv_buf, recv_counts, dtype, disp(op_stream), attr, deps); +} + +CCL_API event alltoallv(const void* send_buf, + const vector_class& send_counts, + void* recv_buf, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_buf, send_counts, recv_buf, recv_counts, dtype, disp(default_stream), attr, deps); +} + +CCL_API event alltoallv(const vector_class& send_bufs, + const vector_class& send_counts, + const vector_class& recv_bufs, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const stream& op_stream, + const alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_bufs, send_counts, recv_bufs, recv_counts, dtype, disp(op_stream), attr, deps); +} + +CCL_API event alltoallv(const vector_class& send_bufs, + const vector_class& send_counts, + const vector_class& recv_bufs, + const vector_class& recv_counts, + datatype dtype, + const communicator& comm, + const 
alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_bufs, send_counts, recv_bufs, recv_counts, dtype, disp(default_stream), attr, deps); +} + +template +event alltoallv(const BufferType* send_buf, + const vector_class& send_counts, + BufferType* recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const stream& op_stream, + const alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_buf, send_counts, recv_buf, recv_counts, disp(op_stream), attr, deps); +} + +template +event alltoallv(const BufferType* send_buf, + const vector_class& send_counts, + BufferType* recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_buf, send_counts, recv_buf, recv_counts, disp(default_stream), attr, deps); +} + +template +event alltoallv(const vector_class& send_bufs, + const vector_class& send_counts, + const vector_class& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const stream& op_stream, + const alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_bufs, send_counts, recv_bufs, recv_counts, disp(op_stream), attr, deps); +} + +template +event alltoallv(const vector_class& send_bufs, + const vector_class& send_counts, + const vector_class& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_bufs, send_counts, recv_bufs, recv_counts, disp(default_stream), attr, deps); +} + +template +event alltoallv(const BufferObjectType& send_buf, + const vector_class& send_counts, + 
BufferObjectType& recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const stream& op_stream, + const alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_buf, send_counts, recv_buf, recv_counts, disp(op_stream), attr, deps); +} + +template +event alltoallv(const BufferObjectType& send_buf, + const vector_class& send_counts, + BufferObjectType& recv_buf, + const vector_class& recv_counts, + const communicator& comm, + const alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_buf, send_counts, recv_buf, recv_counts, disp(default_stream), attr, deps); +} + +template +event alltoallv(const vector_class>& send_bufs, + const vector_class& send_counts, + const vector_class>& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const stream& op_stream, + const alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_bufs, send_counts, recv_bufs, recv_counts, disp(op_stream), attr, deps); +} + +template +event alltoallv(const vector_class>& send_bufs, + const vector_class& send_counts, + const vector_class>& recv_bufs, + const vector_class& recv_counts, + const communicator& comm, + const alltoallv_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->alltoallv( + send_bufs, send_counts, recv_bufs, recv_counts, disp(default_stream), attr, deps); +} + +/* barrier */ +CCL_API event barrier(const communicator& comm, + const stream& op_stream, + const barrier_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->barrier(disp(op_stream), attr, deps); +} + +CCL_API event barrier(const communicator& comm, + const barrier_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch 
disp; + return disp(comm)->barrier(disp(default_stream), attr, deps); +} + +/* broadcast */ +CCL_API event broadcast(void* buf, + size_t count, + datatype dtype, + int root, + const communicator& comm, + const stream& op_stream, + const broadcast_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->bcast(buf, count, dtype, root, disp(op_stream), attr, deps); +} + +CCL_API event broadcast(void* buf, + size_t count, + datatype dtype, + int root, + const communicator& comm, + const broadcast_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->bcast(buf, count, dtype, root, disp(default_stream), attr, deps); +} + +template +event broadcast(BufferType* buf, + size_t count, + int root, + const communicator& comm, + const stream& op_stream, + const broadcast_attr& attr, + const vector_class& deps) + +{ + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->bcast(buf, count, root, disp(op_stream), attr, deps); +} + +template +event broadcast(BufferType* buf, + size_t count, + int root, + const communicator& comm, + const broadcast_attr& attr, + const vector_class& deps) + +{ + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->bcast(buf, count, root, disp(default_stream), attr, deps); +} + +template +event broadcast(BufferObjectType& buf, + size_t count, + int root, + const communicator& comm, + const stream& op_stream, + const broadcast_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->bcast(buf, count, root, disp(op_stream), attr, deps); +} + +template +event broadcast(BufferObjectType& buf, + size_t count, + int root, + const communicator& comm, + const broadcast_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->bcast(buf, count, root, disp(default_stream), attr, deps); +} + +/* reduce */ +CCL_API event reduce(const void* send_buf, + void* 
recv_buf, + size_t count, + datatype dtype, + reduction reduction, + int root, + const communicator& comm, + const stream& op_stream, + const reduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce( + send_buf, recv_buf, count, dtype, reduction, root, disp(op_stream), attr, deps); +} + +CCL_API event reduce(const void* send_buf, + void* recv_buf, + size_t count, + datatype dtype, + reduction reduction, + int root, + const communicator& comm, + const reduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce( + send_buf, recv_buf, count, dtype, reduction, root, disp(default_stream), attr, deps); +} + +template +event reduce(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + reduction reduction, + int root, + const communicator& comm, + const stream& op_stream, + const reduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce( + send_buf, recv_buf, count, reduction, root, disp(op_stream), attr, deps); +} + +template +event reduce(const BufferType* send_buf, + BufferType* recv_buf, + size_t count, + reduction reduction, + int root, + const communicator& comm, + const reduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce( + send_buf, recv_buf, count, reduction, root, disp(default_stream), attr, deps); +} + +template +event reduce(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + reduction reduction, + int root, + const communicator& comm, + const stream& op_stream, + const reduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce( + send_buf, recv_buf, count, reduction, root, disp(op_stream), attr, deps); +} + +template +event reduce(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t count, + 
reduction reduction, + int root, + const communicator& comm, + const reduce_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce( + send_buf, recv_buf, count, reduction, root, disp(default_stream), attr, deps); +} + +/* reduce_scatter */ +CCL_API event reduce_scatter(const void* send_buf, + void* recv_buf, + size_t recv_count, + datatype dtype, + reduction reduction, + const communicator& comm, + const stream& op_stream, + const reduce_scatter_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce_scatter( + send_buf, recv_buf, recv_count, dtype, reduction, disp(op_stream), attr, deps); +} + +CCL_API event reduce_scatter(const void* send_buf, + void* recv_buf, + size_t recv_count, + datatype dtype, + reduction reduction, + const communicator& comm, + const reduce_scatter_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce_scatter( + send_buf, recv_buf, recv_count, dtype, reduction, disp(default_stream), attr, deps); +} + +template +event reduce_scatter(const BufferType* send_buf, + BufferType* recv_buf, + size_t recv_count, + reduction reduction, + const communicator& comm, + const stream& op_stream, + const reduce_scatter_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce_scatter( + send_buf, recv_buf, recv_count, reduction, disp(op_stream), attr, deps); +} + +template +event reduce_scatter(const BufferType* send_buf, + BufferType* recv_buf, + size_t recv_count, + reduction reduction, + const communicator& comm, + const reduce_scatter_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce_scatter( + send_buf, recv_buf, recv_count, reduction, disp(default_stream), attr, deps); +} + +template +event reduce_scatter(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, 
+ size_t recv_count, + reduction reduction, + const communicator& comm, + const stream& op_stream, + const reduce_scatter_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce_scatter( + send_buf, recv_buf, recv_count, reduction, disp(op_stream), attr, deps); +} + +template +event reduce_scatter(const BufferObjectType& send_buf, + BufferObjectType& recv_buf, + size_t recv_count, + reduction reduction, + const communicator& comm, + const reduce_scatter_attr& attr, + const vector_class& deps) { + CHECK_DEPS(deps); + impl_dispatch disp; + return disp(comm)->reduce_scatter( + send_buf, recv_buf, recv_count, reduction, disp(default_stream), attr, deps); +} + +} // namespace v1 + +namespace preview { + +/* sparse_allreduce */ +CCL_API ccl::event sparse_allreduce(const void* send_ind_buf, + size_t send_ind_count, + const void* send_val_buf, + size_t send_val_count, + void* recv_ind_buf, + size_t recv_ind_count, + void* recv_val_buf, + size_t recv_val_count, + ccl::datatype index_dtype, + ccl::datatype value_dtype, + ccl::reduction reduction, + const ccl::communicator& comm, + const ccl::stream& op_stream, + const ccl::sparse_allreduce_attr& attr, + const ccl::vector_class& deps) { + CHECK_DEPS(deps); + ccl::impl_dispatch disp; + return disp(comm)->sparse_allreduce(send_ind_buf, + send_ind_count, + send_val_buf, + send_val_count, + recv_ind_buf, + recv_ind_count, + recv_val_buf, + recv_val_count, + index_dtype, + value_dtype, + reduction, + disp(op_stream), + attr, + deps); +} + +CCL_API ccl::event sparse_allreduce(const void* send_ind_buf, + size_t send_ind_count, + const void* send_val_buf, + size_t send_val_count, + void* recv_ind_buf, + size_t recv_ind_count, + void* recv_val_buf, + size_t recv_val_count, + ccl::datatype index_dtype, + ccl::datatype value_dtype, + ccl::reduction reduction, + const ccl::communicator& comm, + const ccl::sparse_allreduce_attr& attr, + const ccl::vector_class& deps) { + 
CHECK_DEPS(deps); + ccl::impl_dispatch disp; + return disp(comm)->sparse_allreduce(send_ind_buf, + send_ind_count, + send_val_buf, + send_val_count, + recv_ind_buf, + recv_ind_count, + recv_val_buf, + recv_val_count, + index_dtype, + value_dtype, + reduction, + disp(default_stream), + attr, + deps); +} + +template +ccl::event sparse_allreduce(const IndexBufferType* send_ind_buf, + size_t send_ind_count, + const ValueBufferType* send_val_buf, + size_t send_val_count, + IndexBufferType* recv_ind_buf, + size_t recv_ind_count, + ValueBufferType* recv_val_buf, + size_t recv_val_count, + ccl::reduction reduction, + const ccl::communicator& comm, + const ccl::stream& op_stream, + const ccl::sparse_allreduce_attr& attr, + const ccl::vector_class& deps) { + CHECK_DEPS(deps); + ccl::impl_dispatch disp; + return disp(comm)->sparse_allreduce(send_ind_buf, + send_ind_count, + send_val_buf, + send_val_count, + recv_ind_buf, + recv_ind_count, + recv_val_buf, + recv_val_count, + reduction, + disp(op_stream), + attr, + deps); +} + +template +ccl::event sparse_allreduce(const IndexBufferType* send_ind_buf, + size_t send_ind_count, + const ValueBufferType* send_val_buf, + size_t send_val_count, + IndexBufferType* recv_ind_buf, + size_t recv_ind_count, + ValueBufferType* recv_val_buf, + size_t recv_val_count, + ccl::reduction reduction, + const ccl::communicator& comm, + const ccl::sparse_allreduce_attr& attr, + const ccl::vector_class& deps) { + CHECK_DEPS(deps); + ccl::impl_dispatch disp; + return disp(comm)->sparse_allreduce(send_ind_buf, + send_ind_count, + send_val_buf, + send_val_count, + recv_ind_buf, + recv_ind_count, + recv_val_buf, + recv_val_count, + reduction, + disp(default_stream), + attr, + deps); +} + +// template +// ccl::event +// sparse_allreduce(const IndexBufferObjectType& send_ind_buf, +// size_t send_ind_count, +// const ValueBufferObjectType& send_val_buf, +// size_t send_val_count, +// IndexBufferObjectType& recv_ind_buf, +// size_t recv_ind_count, +// 
ValueBufferObjectType& recv_val_buf, +// size_t recv_val_count, +// ccl::reduction reduction, +// const ccl::communicator& comm, +// const ccl::stream& op_stream, +// const ccl::sparse_allreduce_attr& attr, +// const ccl::vector_class& deps) +// { +// CHECK_DEPS(deps); +// ccl::impl_dispatch disp; +// return disp(comm)->sparse_allreduce(send_ind_buf, send_ind_count, +// send_val_buf, send_val_count, +// recv_ind_buf, recv_ind_count, +// recv_val_buf, recv_val_count, +// reduction, +// disp(op_stream), attr, deps); +// } +// +// template +// ccl::event +// sparse_allreduce(const IndexBufferObjectType& send_ind_buf, +// size_t send_ind_count, +// const ValueBufferObjectType& send_val_buf, +// size_t send_val_count, +// IndexBufferObjectType& recv_ind_buf, +// size_t recv_ind_count, +// ValueBufferObjectType& recv_val_buf, +// size_t recv_val_count, +// ccl::reduction reduction, +// const ccl::communicator& comm, +// const ccl::sparse_allreduce_attr& attr, +// const ccl::vector_class& deps) +// { +// CHECK_DEPS(deps); +// ccl::impl_dispatch disp; +// return disp(comm)->sparse_allreduce(send_ind_buf, send_ind_count, +// send_val_buf, send_val_count, +// recv_ind_buf, recv_ind_count, +// recv_val_buf, recv_val_count, +// reduction, +// disp(default_stream), attr, deps); +// } + +} // namespace preview + +namespace v1 { + +// API force instantiations for Operations +API_COMM_OP_PTR_EXPLICIT_INSTANTIATION(int8_t); +API_COMM_OP_PTR_EXPLICIT_INSTANTIATION(uint8_t); +API_COMM_OP_PTR_EXPLICIT_INSTANTIATION(int16_t); +API_COMM_OP_PTR_EXPLICIT_INSTANTIATION(uint16_t); +API_COMM_OP_PTR_EXPLICIT_INSTANTIATION(int32_t); +API_COMM_OP_PTR_EXPLICIT_INSTANTIATION(uint32_t); +API_COMM_OP_PTR_EXPLICIT_INSTANTIATION(int64_t); +API_COMM_OP_PTR_EXPLICIT_INSTANTIATION(uint64_t); +API_COMM_OP_PTR_EXPLICIT_INSTANTIATION(float); +API_COMM_OP_PTR_EXPLICIT_INSTANTIATION(double); + +#ifdef CCL_ENABLE_SYCL +#ifndef COMMA +#define COMMA , +#endif + 
+API_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); +API_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); +API_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); +API_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); +API_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); +API_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); +API_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); +API_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); +API_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); +API_COMM_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer); + +#undef COMMA +#endif // CCL_ENABLE_SYCL + +} // namespace v1 + +namespace preview { + +API_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int32_t, float); +API_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int32_t, ccl::bfloat16); +API_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int64_t, float); +API_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(int64_t, ccl::bfloat16); + +// #ifdef CCL_ENABLE_SYCL +// #ifndef COMMA +// #define COMMA , +// #endif +// API_COMM_SPARSE_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer, +// cl::sycl::buffer); +// API_COMM_SPARSE_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer, +// cl::sycl::buffer); + +// API_COMM_SPARSE_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer, +// cl::sycl::buffer); +// API_COMM_SPARSE_OP_REF_EXPLICIT_INSTANTIATION(cl::sycl::buffer, +// cl::sycl::buffer); +// #undef COMMA +// #endif //CCL_ENABLE_SYCL + +} // namespace preview + +} // namespace ccl diff --git a/src/ccl_api_functions_generators.hpp b/src/ccl_api_functions_generators.hpp index b02c2bfd1..0fc150dbb 100644 --- a/src/ccl_api_functions_generators.hpp +++ b/src/ccl_api_functions_generators.hpp @@ -13,424 +13,426 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#pragma once - -namespace ccl { - -#define CREATE_OP_ATTR_INSTANTIATION(attr) template attr CCL_API create_operation_attr(); - -/******************** DEVICE COMMUNICATOR ********************/ - -/** - * Generating API types for collective operations - * of the device communicator class (communicator) - */ -#define API_DEVICE_COMM_OP_PTR_EXPLICIT_INSTANTIATION(BufferType) \ -\ - template event CCL_API allgatherv(const BufferType* send_buf, \ - size_t send_count, \ - BufferType* recv_buf, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const stream& op_stream, \ - const allgatherv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API allgatherv(const BufferType* send_buf, \ - size_t send_count, \ - BufferType* recv_buf, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const allgatherv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API allgatherv(const BufferType* send_buf, \ - size_t send_count, \ - vector_class& recv_bufs, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const stream& op_stream, \ - const allgatherv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API allgatherv(const BufferType* send_buf, \ - size_t send_count, \ - vector_class& recv_bufs, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const allgatherv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API allreduce(const BufferType* send_buf, \ - BufferType* recv_buf, \ - size_t count, \ - reduction reduction, \ - const communicator& comm, \ - const stream& op_stream, \ - const allreduce_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API allreduce(const BufferType* send_buf, \ - BufferType* recv_buf, \ - size_t count, \ - reduction reduction, \ - const communicator& comm, \ - const allreduce_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoall(const BufferType* send_buf, \ - 
BufferType* recv_buf, \ - size_t count, \ - const communicator& comm, \ - const stream& op_stream, \ - const alltoall_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoall(const BufferType* send_buf, \ - BufferType* recv_buf, \ - size_t count, \ - const communicator& comm, \ - const alltoall_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoall(const vector_class& send_buf, \ - const vector_class& recv_buf, \ - size_t count, \ - const communicator& comm, \ - const stream& op_stream, \ - const alltoall_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoall(const vector_class& send_buf, \ - const vector_class& recv_buf, \ - size_t count, \ - const communicator& comm, \ - const alltoall_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoallv(const BufferType* send_buf, \ - const vector_class& send_counts, \ - BufferType* recv_buf, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const stream& op_stream, \ - const alltoallv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoallv(const BufferType* send_buf, \ - const vector_class& send_counts, \ - BufferType* recv_buf, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const alltoallv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoallv(const vector_class& send_bufs, \ - const vector_class& send_counts, \ - const vector_class& recv_bufs, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const stream& op_stream, \ - const alltoallv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoallv(const vector_class& send_bufs, \ - const vector_class& send_counts, \ - const vector_class& recv_bufs, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const alltoallv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API broadcast(BufferType* 
buf, \ - size_t count, \ - size_t root, \ - const communicator& comm, \ - const stream& op_stream, \ - const broadcast_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API broadcast(BufferType* buf, \ - size_t count, \ - size_t root, \ - const communicator& comm, \ - const broadcast_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API reduce(const BufferType* send_buf, \ - BufferType* recv_buf, \ - size_t count, \ - reduction reduction, \ - size_t root, \ - const communicator& comm, \ - const stream& op_stream, \ - const reduce_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API reduce(const BufferType* send_buf, \ - BufferType* recv_buf, \ - size_t count, \ - reduction reduction, \ - size_t root, \ - const communicator& comm, \ - const reduce_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API reduce_scatter(const BufferType* send_buf, \ - BufferType* recv_buf, \ - size_t recv_count, \ - reduction reduction, \ - const communicator& comm, \ - const stream& op_stream, \ - const reduce_scatter_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API reduce_scatter(const BufferType* send_buf, \ - BufferType* recv_buf, \ - size_t recv_count, \ - reduction reduction, \ - const communicator& comm, \ - const reduce_scatter_attr& attr, \ - const vector_class& deps); - -#define API_DEVICE_COMM_OP_REF_EXPLICIT_INSTANTIATION(BufferObjectType) \ -\ - template event CCL_API allgatherv(const BufferObjectType& send_buf, \ - size_t send_count, \ - BufferObjectType& recv_buf, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const stream& op_stream, \ - const allgatherv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API allgatherv(const BufferObjectType& send_buf, \ - size_t send_count, \ - BufferObjectType& recv_buf, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const allgatherv_attr& attr, \ - const vector_class& 
deps); \ -\ - template event CCL_API allgatherv( \ - const BufferObjectType& send_buf, \ - size_t send_count, \ - vector_class>& recv_bufs, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const stream& op_stream, \ - const allgatherv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API allgatherv( \ - const BufferObjectType& send_buf, \ - size_t send_count, \ - vector_class>& recv_bufs, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const allgatherv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API allreduce(const BufferObjectType& send_buf, \ - BufferObjectType& recv_buf, \ - size_t count, \ - reduction reduction, \ - const communicator& comm, \ - const stream& op_stream, \ - const allreduce_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API allreduce(const BufferObjectType& send_buf, \ - BufferObjectType& recv_buf, \ - size_t count, \ - reduction reduction, \ - const communicator& comm, \ - const allreduce_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoall(const BufferObjectType& send_buf, \ - BufferObjectType& recv_buf, \ - size_t count, \ - const communicator& comm, \ - const stream& op_stream, \ - const alltoall_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoall(const BufferObjectType& send_buf, \ - BufferObjectType& recv_buf, \ - size_t count, \ - const communicator& comm, \ - const alltoall_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoall( \ - const vector_class>& send_buf, \ - const vector_class>& recv_buf, \ - size_t count, \ - const communicator& comm, \ - const stream& op_stream, \ - const alltoall_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoall( \ - const vector_class>& send_buf, \ - const vector_class>& recv_buf, \ - size_t count, \ - const communicator& comm, \ - const alltoall_attr& attr, \ - const vector_class& 
deps); \ -\ - template event CCL_API alltoallv(const BufferObjectType& send_buf, \ - const vector_class& send_counts, \ - BufferObjectType& recv_buf, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const stream& op_stream, \ - const alltoallv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoallv(const BufferObjectType& send_buf, \ - const vector_class& send_counts, \ - BufferObjectType& recv_buf, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const alltoallv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoallv( \ - const vector_class>& send_bufs, \ - const vector_class& send_counts, \ - const vector_class>& recv_bufs, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const stream& op_stream, \ - const alltoallv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API alltoallv( \ - const vector_class>& send_bufs, \ - const vector_class& send_counts, \ - const vector_class>& recv_bufs, \ - const vector_class& recv_counts, \ - const communicator& comm, \ - const alltoallv_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API broadcast(BufferObjectType& buf, \ - size_t count, \ - size_t root, \ - const communicator& comm, \ - const stream& op_stream, \ - const broadcast_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API broadcast(BufferObjectType& buf, \ - size_t count, \ - size_t root, \ - const communicator& comm, \ - const broadcast_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API reduce(const BufferObjectType& send_buf, \ - BufferObjectType& recv_buf, \ - size_t count, \ - reduction reduction, \ - size_t root, \ - const communicator& comm, \ - const stream& op_stream, \ - const reduce_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API reduce(const BufferObjectType& send_buf, \ - BufferObjectType& recv_buf, \ - size_t count, \ - reduction 
reduction, \ - size_t root, \ - const communicator& comm, \ - const reduce_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API reduce_scatter(const BufferObjectType& send_buf, \ - BufferObjectType& recv_buf, \ - size_t recv_count, \ - reduction reduction, \ - const communicator& comm, \ - const stream& op_stream, \ - const reduce_scatter_attr& attr, \ - const vector_class& deps); \ -\ - template event CCL_API reduce_scatter(const BufferObjectType& send_buf, \ - BufferObjectType& recv_buf, \ - size_t recv_count, \ - reduction reduction, \ - const communicator& comm, \ - const reduce_scatter_attr& attr, \ - const vector_class& deps); - -namespace preview { - -#define API_DEVICE_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(index_type, value_type) \ -\ - template ccl::event CCL_API sparse_allreduce(const index_type* send_ind_buf, \ - size_t send_ind_count, \ - const value_type* send_val_buf, \ - size_t send_val_count, \ - index_type* recv_ind_buf, \ - size_t recv_ind_count, \ - value_type* recv_val_buf, \ - size_t recv_val_count, \ - ccl::reduction reduction, \ - const ccl::communicator& comm, \ - const ccl::stream& op_stream, \ - const ccl::sparse_allreduce_attr& attr, \ - const ccl::vector_class& deps); \ -\ - template ccl::event CCL_API sparse_allreduce(const index_type* send_ind_buf, \ - size_t send_ind_count, \ - const value_type* send_val_buf, \ - size_t send_val_count, \ - index_type* recv_ind_buf, \ - size_t recv_ind_count, \ - value_type* recv_val_buf, \ - size_t recv_val_count, \ - ccl::reduction reduction, \ - const ccl::communicator& comm, \ - const ccl::sparse_allreduce_attr& attr, \ - const ccl::vector_class& deps); - -/* -#define API_DEVICE_COMM_SPARSE_OP_REF_EXPLICIT_INSTANTIATION(index_object_type, value_object_type) \ -\ -template ccl::event CCL_API \ -sparse_allreduce(const index_object_type& send_ind_buf, \ - size_t send_ind_count, \ - const value_object_type& send_val_buf, \ - size_t send_val_count, \ - index_object_type& 
recv_ind_buf, \ - size_t recv_ind_count, \ - value_object_type& recv_val_buf, \ - size_t recv_val_count, \ - ccl::reduction reduction, \ - const ccl::communicator& comm, \ - const ccl::stream& op_stream, \ - const ccl::sparse_allreduce_attr& attr, \ - const ccl::vector_class& deps); \ -\ -template ccl::event CCL_API \ -sparse_allreduce(const index_object_type& send_ind_buf, \ - size_t send_ind_count, \ - const value_object_type& send_val_buf, \ - size_t send_val_count, \ - index_object_type& recv_ind_buf, \ - size_t recv_ind_count, \ - value_object_type& recv_val_buf, \ - size_t recv_val_count, \ - ccl::reduction reduction, \ - const ccl::communicator& comm, \ - const ccl::sparse_allreduce_attr& attr, \ - const ccl::vector_class& deps); -*/ - -} // namespace preview - -} // namespace ccl +#pragma once + +namespace ccl { + +namespace v1 { + +/******************** COMMUNICATOR ********************/ + +/** + * Generating API types for collective operations + * of the communicator class (communicator) + */ +#define API_COMM_OP_PTR_EXPLICIT_INSTANTIATION(BufferType) \ +\ + template event CCL_API allgatherv(const BufferType* send_buf, \ + size_t send_count, \ + BufferType* recv_buf, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const stream& op_stream, \ + const allgatherv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API allgatherv(const BufferType* send_buf, \ + size_t send_count, \ + BufferType* recv_buf, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const allgatherv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API allgatherv(const BufferType* send_buf, \ + size_t send_count, \ + vector_class& recv_bufs, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const stream& op_stream, \ + const allgatherv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API allgatherv(const BufferType* send_buf, \ + size_t send_count, \ + vector_class& 
recv_bufs, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const allgatherv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API allreduce(const BufferType* send_buf, \ + BufferType* recv_buf, \ + size_t count, \ + reduction reduction, \ + const communicator& comm, \ + const stream& op_stream, \ + const allreduce_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API allreduce(const BufferType* send_buf, \ + BufferType* recv_buf, \ + size_t count, \ + reduction reduction, \ + const communicator& comm, \ + const allreduce_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoall(const BufferType* send_buf, \ + BufferType* recv_buf, \ + size_t count, \ + const communicator& comm, \ + const stream& op_stream, \ + const alltoall_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoall(const BufferType* send_buf, \ + BufferType* recv_buf, \ + size_t count, \ + const communicator& comm, \ + const alltoall_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoall(const vector_class& send_buf, \ + const vector_class& recv_buf, \ + size_t count, \ + const communicator& comm, \ + const stream& op_stream, \ + const alltoall_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoall(const vector_class& send_buf, \ + const vector_class& recv_buf, \ + size_t count, \ + const communicator& comm, \ + const alltoall_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoallv(const BufferType* send_buf, \ + const vector_class& send_counts, \ + BufferType* recv_buf, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const stream& op_stream, \ + const alltoallv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoallv(const BufferType* send_buf, \ + const vector_class& send_counts, \ + BufferType* recv_buf, \ + const vector_class& recv_counts, \ + const 
communicator& comm, \ + const alltoallv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoallv(const vector_class& send_bufs, \ + const vector_class& send_counts, \ + const vector_class& recv_bufs, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const stream& op_stream, \ + const alltoallv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoallv(const vector_class& send_bufs, \ + const vector_class& send_counts, \ + const vector_class& recv_bufs, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const alltoallv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API broadcast(BufferType* buf, \ + size_t count, \ + int root, \ + const communicator& comm, \ + const stream& op_stream, \ + const broadcast_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API broadcast(BufferType* buf, \ + size_t count, \ + int root, \ + const communicator& comm, \ + const broadcast_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API reduce(const BufferType* send_buf, \ + BufferType* recv_buf, \ + size_t count, \ + reduction reduction, \ + int root, \ + const communicator& comm, \ + const stream& op_stream, \ + const reduce_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API reduce(const BufferType* send_buf, \ + BufferType* recv_buf, \ + size_t count, \ + reduction reduction, \ + int root, \ + const communicator& comm, \ + const reduce_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API reduce_scatter(const BufferType* send_buf, \ + BufferType* recv_buf, \ + size_t recv_count, \ + reduction reduction, \ + const communicator& comm, \ + const stream& op_stream, \ + const reduce_scatter_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API reduce_scatter(const BufferType* send_buf, \ + BufferType* recv_buf, \ + size_t recv_count, \ + reduction reduction, \ + const 
communicator& comm, \ + const reduce_scatter_attr& attr, \ + const vector_class& deps); + +#define API_COMM_OP_REF_EXPLICIT_INSTANTIATION(BufferObjectType) \ +\ + template event CCL_API allgatherv(const BufferObjectType& send_buf, \ + size_t send_count, \ + BufferObjectType& recv_buf, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const stream& op_stream, \ + const allgatherv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API allgatherv(const BufferObjectType& send_buf, \ + size_t send_count, \ + BufferObjectType& recv_buf, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const allgatherv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API allgatherv( \ + const BufferObjectType& send_buf, \ + size_t send_count, \ + vector_class>& recv_bufs, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const stream& op_stream, \ + const allgatherv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API allgatherv( \ + const BufferObjectType& send_buf, \ + size_t send_count, \ + vector_class>& recv_bufs, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const allgatherv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API allreduce(const BufferObjectType& send_buf, \ + BufferObjectType& recv_buf, \ + size_t count, \ + reduction reduction, \ + const communicator& comm, \ + const stream& op_stream, \ + const allreduce_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API allreduce(const BufferObjectType& send_buf, \ + BufferObjectType& recv_buf, \ + size_t count, \ + reduction reduction, \ + const communicator& comm, \ + const allreduce_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoall(const BufferObjectType& send_buf, \ + BufferObjectType& recv_buf, \ + size_t count, \ + const communicator& comm, \ + const stream& op_stream, \ + const alltoall_attr& attr, \ + const 
vector_class& deps); \ +\ + template event CCL_API alltoall(const BufferObjectType& send_buf, \ + BufferObjectType& recv_buf, \ + size_t count, \ + const communicator& comm, \ + const alltoall_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoall( \ + const vector_class>& send_buf, \ + const vector_class>& recv_buf, \ + size_t count, \ + const communicator& comm, \ + const stream& op_stream, \ + const alltoall_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoall( \ + const vector_class>& send_buf, \ + const vector_class>& recv_buf, \ + size_t count, \ + const communicator& comm, \ + const alltoall_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoallv(const BufferObjectType& send_buf, \ + const vector_class& send_counts, \ + BufferObjectType& recv_buf, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const stream& op_stream, \ + const alltoallv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoallv(const BufferObjectType& send_buf, \ + const vector_class& send_counts, \ + BufferObjectType& recv_buf, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const alltoallv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoallv( \ + const vector_class>& send_bufs, \ + const vector_class& send_counts, \ + const vector_class>& recv_bufs, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const stream& op_stream, \ + const alltoallv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API alltoallv( \ + const vector_class>& send_bufs, \ + const vector_class& send_counts, \ + const vector_class>& recv_bufs, \ + const vector_class& recv_counts, \ + const communicator& comm, \ + const alltoallv_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API broadcast(BufferObjectType& buf, \ + size_t count, \ + int root, \ + const communicator& comm, \ + 
const stream& op_stream, \ + const broadcast_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API broadcast(BufferObjectType& buf, \ + size_t count, \ + int root, \ + const communicator& comm, \ + const broadcast_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API reduce(const BufferObjectType& send_buf, \ + BufferObjectType& recv_buf, \ + size_t count, \ + reduction reduction, \ + int root, \ + const communicator& comm, \ + const stream& op_stream, \ + const reduce_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API reduce(const BufferObjectType& send_buf, \ + BufferObjectType& recv_buf, \ + size_t count, \ + reduction reduction, \ + int root, \ + const communicator& comm, \ + const reduce_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API reduce_scatter(const BufferObjectType& send_buf, \ + BufferObjectType& recv_buf, \ + size_t recv_count, \ + reduction reduction, \ + const communicator& comm, \ + const stream& op_stream, \ + const reduce_scatter_attr& attr, \ + const vector_class& deps); \ +\ + template event CCL_API reduce_scatter(const BufferObjectType& send_buf, \ + BufferObjectType& recv_buf, \ + size_t recv_count, \ + reduction reduction, \ + const communicator& comm, \ + const reduce_scatter_attr& attr, \ + const vector_class& deps); + +} // namespace v1 + +namespace preview { + +#define API_COMM_SPARSE_OP_PTR_EXPLICIT_INSTANTIATION(index_type, value_type) \ +\ + template ccl::event CCL_API sparse_allreduce(const index_type* send_ind_buf, \ + size_t send_ind_count, \ + const value_type* send_val_buf, \ + size_t send_val_count, \ + index_type* recv_ind_buf, \ + size_t recv_ind_count, \ + value_type* recv_val_buf, \ + size_t recv_val_count, \ + ccl::reduction reduction, \ + const ccl::communicator& comm, \ + const ccl::stream& op_stream, \ + const ccl::sparse_allreduce_attr& attr, \ + const ccl::vector_class& deps); \ +\ + template ccl::event CCL_API sparse_allreduce(const 
index_type* send_ind_buf, \ + size_t send_ind_count, \ + const value_type* send_val_buf, \ + size_t send_val_count, \ + index_type* recv_ind_buf, \ + size_t recv_ind_count, \ + value_type* recv_val_buf, \ + size_t recv_val_count, \ + ccl::reduction reduction, \ + const ccl::communicator& comm, \ + const ccl::sparse_allreduce_attr& attr, \ + const ccl::vector_class& deps); + +/* +#define API_COMM_SPARSE_OP_REF_EXPLICIT_INSTANTIATION(index_object_type, value_object_type) \ +\ +template ccl::event CCL_API \ +sparse_allreduce(const index_object_type& send_ind_buf, \ + size_t send_ind_count, \ + const value_object_type& send_val_buf, \ + size_t send_val_count, \ + index_object_type& recv_ind_buf, \ + size_t recv_ind_count, \ + value_object_type& recv_val_buf, \ + size_t recv_val_count, \ + ccl::reduction reduction, \ + const ccl::communicator& comm, \ + const ccl::stream& op_stream, \ + const ccl::sparse_allreduce_attr& attr, \ + const ccl::vector_class& deps); \ +\ +template ccl::event CCL_API \ +sparse_allreduce(const index_object_type& send_ind_buf, \ + size_t send_ind_count, \ + const value_object_type& send_val_buf, \ + size_t send_val_count, \ + index_object_type& recv_ind_buf, \ + size_t recv_ind_count, \ + value_object_type& recv_val_buf, \ + size_t recv_val_count, \ + ccl::reduction reduction, \ + const ccl::communicator& comm, \ + const ccl::sparse_allreduce_attr& attr, \ + const ccl::vector_class& deps); +*/ + +} // namespace preview + +} // namespace ccl diff --git a/src/ccl_app_api_coll_attr.cpp b/src/ccl_app_api_coll_attr.cpp index 63b6829f0..1c519531b 100644 --- a/src/ccl_app_api_coll_attr.cpp +++ b/src/ccl_app_api_coll_attr.cpp @@ -13,28 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_aliases.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_coll_attr.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" +#include "oneapi/ccl/coll_attr.hpp" // Core file with PIMPL implementation #include "coll_attr_impl.hpp" #include "coll/coll_attributes.hpp" namespace ccl { + +namespace v1 { + #define COMMA , #define API_FORCE_INSTANTIATION_SET(class_name, IN_attrType, IN_attrId, IN_Value) \ template CCL_API \ - typename details::ccl_api_type_attr_traits::return_type \ + typename detail::ccl_api_type_attr_traits::return_type \ class_name::set(const IN_Value& v); #define API_FORCE_INSTANTIATION_GET(class_name, IN_attrType, IN_attrId) \ - template CCL_API const typename details::ccl_api_type_attr_traits::return_type& \ + template CCL_API const typename detail::ccl_api_type_attr_traits::return_type& \ class_name::get() const; #define API_FORCE_INSTANTIATION(class_name, IN_attrType, IN_attrId, IN_Value) \ @@ -44,10 +47,10 @@ namespace ccl { #define COMMON_API_FORCE_INSTANTIATION(class_name) \ API_FORCE_INSTANTIATION( \ class_name, operation_attr_id, operation_attr_id::version, ccl::library_version) \ - API_FORCE_INSTANTIATION( \ - class_name, operation_attr_id, operation_attr_id::prologue_fn, ccl::prologue_fn) \ - API_FORCE_INSTANTIATION( \ - class_name, operation_attr_id, operation_attr_id::epilogue_fn, ccl::epilogue_fn) \ + /*API_FORCE_INSTANTIATION(*/ \ + /*class_name, operation_attr_id, operation_attr_id::prologue_fn, ccl::prologue_fn)*/ \ + /*API_FORCE_INSTANTIATION(*/ \ + /*class_name, operation_attr_id, operation_attr_id::epilogue_fn, ccl::epilogue_fn)*/ \ \ API_FORCE_INSTANTIATION_SET( \ class_name, operation_attr_id, 
operation_attr_id::priority, size_t) \ @@ -69,9 +72,9 @@ CCL_API allgatherv_attr::allgatherv_attr(allgatherv_attr&& src) : base_t(std::mo CCL_API allgatherv_attr::allgatherv_attr(const allgatherv_attr& src) : base_t(src) {} CCL_API allgatherv_attr::allgatherv_attr( - const typename details::ccl_api_type_attr_traits::type& version) - : base_t(std::shared_ptr(new impl_t(version))) {} + const typename detail::ccl_api_type_attr_traits::type& version) + : base_t(impl_value_t(new impl_t(version))) {} CCL_API allgatherv_attr::~allgatherv_attr() {} @@ -83,9 +86,9 @@ CCL_API allreduce_attr::allreduce_attr(allreduce_attr&& src) : base_t(std::move( CCL_API allreduce_attr::allreduce_attr(const allreduce_attr& src) : base_t(src) {} CCL_API allreduce_attr::allreduce_attr( - const typename details::ccl_api_type_attr_traits::type& version) - : base_t(std::shared_ptr(new impl_t(version))) {} + const typename detail::ccl_api_type_attr_traits::type& version) + : base_t(impl_value_t(new impl_t(version))) {} CCL_API allreduce_attr::~allreduce_attr() {} @@ -97,9 +100,9 @@ CCL_API alltoall_attr::alltoall_attr(alltoall_attr&& src) : base_t(std::move(src CCL_API alltoall_attr::alltoall_attr(const alltoall_attr& src) : base_t(src) {} CCL_API alltoall_attr::alltoall_attr( - const typename details::ccl_api_type_attr_traits::type& version) - : base_t(std::shared_ptr(new impl_t(version))) {} + const typename detail::ccl_api_type_attr_traits::type& version) + : base_t(impl_value_t(new impl_t(version))) {} CCL_API alltoall_attr::~alltoall_attr() {} @@ -111,9 +114,9 @@ CCL_API alltoallv_attr::alltoallv_attr(alltoallv_attr&& src) : base_t(std::move( CCL_API alltoallv_attr::alltoallv_attr(const alltoallv_attr& src) : base_t(src) {} CCL_API alltoallv_attr::alltoallv_attr( - const typename details::ccl_api_type_attr_traits::type& version) - : base_t(std::shared_ptr(new impl_t(version))) {} + const typename detail::ccl_api_type_attr_traits::type& version) + : base_t(impl_value_t(new impl_t(version))) 
{} CCL_API alltoallv_attr::~alltoallv_attr() {} @@ -125,9 +128,9 @@ CCL_API barrier_attr::barrier_attr(barrier_attr&& src) : base_t(std::move(src)) CCL_API barrier_attr::barrier_attr(const barrier_attr& src) : base_t(src) {} CCL_API barrier_attr::barrier_attr( - const typename details::ccl_api_type_attr_traits::type& version) - : base_t(std::shared_ptr(new impl_t(version))) {} + const typename detail::ccl_api_type_attr_traits::type& version) + : base_t(impl_value_t(new impl_t(version))) {} CCL_API barrier_attr::~barrier_attr() {} @@ -139,9 +142,9 @@ CCL_API broadcast_attr::broadcast_attr(broadcast_attr&& src) : base_t(std::move( CCL_API broadcast_attr::broadcast_attr(const broadcast_attr& src) : base_t(src) {} CCL_API broadcast_attr::broadcast_attr( - const typename details::ccl_api_type_attr_traits::type& version) - : base_t(std::shared_ptr(new impl_t(version))) {} + const typename detail::ccl_api_type_attr_traits::type& version) + : base_t(impl_value_t(new impl_t(version))) {} CCL_API broadcast_attr::~broadcast_attr() {} @@ -153,9 +156,9 @@ CCL_API reduce_attr::reduce_attr(reduce_attr&& src) : base_t(std::move(src)) {} CCL_API reduce_attr::reduce_attr(const reduce_attr& src) : base_t(src) {} CCL_API reduce_attr::reduce_attr( - const typename details::ccl_api_type_attr_traits::type& version) - : base_t(std::shared_ptr(new impl_t(version))) {} + const typename detail::ccl_api_type_attr_traits::type& version) + : base_t(impl_value_t(new impl_t(version))) {} CCL_API reduce_attr::~reduce_attr() {} @@ -168,9 +171,9 @@ CCL_API reduce_scatter_attr::reduce_scatter_attr(reduce_scatter_attr&& src) CCL_API reduce_scatter_attr::reduce_scatter_attr(const reduce_scatter_attr& src) : base_t(src) {} CCL_API reduce_scatter_attr::reduce_scatter_attr( - const typename details::ccl_api_type_attr_traits::type& version) - : base_t(std::shared_ptr(new impl_t(version))) {} + const typename detail::ccl_api_type_attr_traits::type& version) + : base_t(impl_value_t(new impl_t(version))) {} 
CCL_API reduce_scatter_attr::~reduce_scatter_attr() {} @@ -184,9 +187,9 @@ CCL_API sparse_allreduce_attr::sparse_allreduce_attr(const sparse_allreduce_attr : base_t(src) {} CCL_API sparse_allreduce_attr::sparse_allreduce_attr( - const typename details::ccl_api_type_attr_traits::type& version) - : base_t(std::shared_ptr(new impl_t(version))) {} + const typename detail::ccl_api_type_attr_traits::type& version) + : base_t(impl_value_t(new impl_t(version))) {} CCL_API sparse_allreduce_attr::~sparse_allreduce_attr() {} @@ -195,14 +198,14 @@ CCL_API const void* sparse_allreduce_attr::setset_attribute_value( v, - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } template <> CCL_API const void* const& sparse_allreduce_attr::get() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } /** @@ -248,4 +251,7 @@ API_FORCE_INSTANTIATION(sparse_allreduce_attr, #undef API_FORCE_INSTANTIATION #undef COMMON_API_FORCE_INSTANTIATION #undef COMMA + +} // namespace v1 + } // namespace ccl diff --git a/src/ccl_app_api_comm_attr.cpp b/src/ccl_app_api_comm_attr.cpp new file mode 100644 index 000000000..5468c6bd5 --- /dev/null +++ b/src/ccl_app_api_comm_attr.cpp @@ -0,0 +1,77 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/comm_attr_ids.hpp" +#include "oneapi/ccl/comm_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_attr.hpp" + +// Core file with PIMPL implementation +#include "common/comm/comm_common_attr.hpp" +#include "comm_attr_impl.hpp" + +namespace ccl { + +namespace v1 { + +#define API_FORCE_INSTANTIATION(class_name, IN_attrId, IN_Value, OUT_Traits_Value) \ + template CCL_API IN_Value class_name::set(const IN_Value& v); \ +\ + template CCL_API const typename OUT_Traits_Value::type& \ + class_name::get() const; \ +\ + template CCL_API bool class_name::is_valid() const noexcept; + +/** + * comm_attr attributes definition + */ +CCL_API comm_attr::comm_attr(ccl_empty_attr) + : base_t(impl_value_t(new impl_t(ccl_empty_attr::version))) {} +CCL_API comm_attr::comm_attr(comm_attr&& src) : base_t(std::move(src)) {} + +CCL_API comm_attr::comm_attr(const comm_attr& src) : base_t(src) {} + +CCL_API comm_attr::comm_attr( + const typename detail::ccl_api_type_attr_traits::return_type& version) + : base_t(impl_value_t(new impl_t(version))) {} + +CCL_API comm_attr::~comm_attr() noexcept {} + +CCL_API comm_attr& comm_attr::operator=(const comm_attr& src) { + this->get_impl() = src.get_impl(); + return *this; +} + +CCL_API comm_attr& comm_attr::operator=(comm_attr&& src) { + if (src.get_impl() != this->get_impl()) { + src.get_impl().swap(this->get_impl()); + src.get_impl().reset(); + } + return *this; +} + +API_FORCE_INSTANTIATION(comm_attr, + comm_attr_id::version, + ccl::library_version, + detail::ccl_api_type_attr_traits) + +#undef API_FORCE_INSTANTIATION + +} // namespace v1 + +} // namespace ccl diff --git a/src/ccl_app_api_comm_split_attr.cpp b/src/ccl_app_api_comm_split_attr.cpp index 3a7c2dba8..71f4b70dd 100644 --- a/src/ccl_app_api_comm_split_attr.cpp +++ b/src/ccl_app_api_comm_split_attr.cpp @@ -13,24 +13,25 @@ See the License for the specific language 
governing permissions and limitations under the License. */ -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_aliases.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_comm_split_attr_ids.hpp" -#include "oneapi/ccl/ccl_comm_split_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_comm_split_attr.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/comm_split_attr_ids.hpp" +#include "oneapi/ccl/comm_split_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_split_attr.hpp" // Core file with PIMPL implementation #include "common/comm/comm_split_common_attr.hpp" #include "comm_split_attr_impl.hpp" namespace ccl { -#define COMMA , + +namespace v1 { + #define API_FORCE_INSTANTIATION(class_name, IN_attrId, IN_Value, OUT_Traits_Value) \ template CCL_API IN_Value class_name::set(const IN_Value& v); \ \ - template CCL_API const typename details::OUT_Traits_Value::type& \ + template CCL_API const typename OUT_Traits_Value::type& \ class_name::get() const; \ \ template CCL_API bool class_name::is_valid() const noexcept; @@ -39,17 +40,15 @@ namespace ccl { * comm_split_attr attributes definition */ CCL_API comm_split_attr::comm_split_attr(ccl_empty_attr) - : base_t(std::shared_ptr(new impl_t(ccl_empty_attr::version))) {} -CCL_API comm_split_attr::comm_split_attr(comm_split_attr&& src) - : base_t(std::move(src)) {} + : base_t(impl_value_t(new impl_t(ccl_empty_attr::version))) {} +CCL_API comm_split_attr::comm_split_attr(comm_split_attr&& src) : base_t(std::move(src)) {} -CCL_API comm_split_attr::comm_split_attr(const comm_split_attr& src) - : base_t(src) {} +CCL_API comm_split_attr::comm_split_attr(const comm_split_attr& src) : base_t(src) {} CCL_API comm_split_attr::comm_split_attr( - const typename details::ccl_api_type_attr_traits::type& version) - : base_t(std::shared_ptr(new impl_t(version))) {} + const typename detail::ccl_api_type_attr_traits::type& version) + : 
base_t(impl_value_t(new impl_t(version))) {} CCL_API comm_split_attr::~comm_split_attr() noexcept {} @@ -65,19 +64,22 @@ CCL_API comm_split_attr& comm_split_attr::operator=(comm_split_attr&& src) { } return *this; } + API_FORCE_INSTANTIATION(comm_split_attr, comm_split_attr_id::color, int, - ccl_api_type_attr_traits) + detail::ccl_api_type_attr_traits) API_FORCE_INSTANTIATION(comm_split_attr, comm_split_attr_id::group, - group_split_type, - ccl_api_type_attr_traits) + split_group, + detail::ccl_api_type_attr_traits) API_FORCE_INSTANTIATION(comm_split_attr, comm_split_attr_id::version, ccl::library_version, - ccl_api_type_attr_traits) + detail::ccl_api_type_attr_traits) #undef API_FORCE_INSTANTIATION -#undef COMMA + +} // namespace v1 + } // namespace ccl diff --git a/src/ccl_app_api_datatype_attr.cpp b/src/ccl_app_api_datatype_attr.cpp index ed7ca3bc5..0f1dcfacc 100644 --- a/src/ccl_app_api_datatype_attr.cpp +++ b/src/ccl_app_api_datatype_attr.cpp @@ -13,75 +13,74 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_aliases.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_datatype_attr_ids.hpp" -#include "oneapi/ccl/ccl_datatype_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_datatype_attr.hpp" - -// Core file with PIMPL implementation -#include "common/datatype/datatype_attr.hpp" -#include "datatype_attr_impl.hpp" - -namespace ccl { - -#define COMMA , -#define API_FORCE_SETTER_INSTANTIATION(class_name, IN_attrId, IN_Value, OUT_Traits_Value) \ - template CCL_API IN_Value class_name::set(const IN_Value& v); - -#define API_FORCE_GETTER_INSTANTIATION(class_name, IN_attrId, IN_Value, OUT_Traits_Value) \ - template CCL_API const typename details::OUT_Traits_Value::return_type& \ - class_name::get() const; - -/** - * datatype_attr attributes definition - */ -CCL_API datatype_attr::datatype_attr(datatype_attr&& src) : base_t(std::move(src)) {} - -CCL_API datatype_attr::datatype_attr(const datatype_attr& src) : base_t(src) {} - -CCL_API datatype_attr::datatype_attr( - const typename details::ccl_api_type_attr_traits::return_type& - version) - : base_t(std::shared_ptr(new impl_t(version))) {} - -CCL_API datatype_attr::~datatype_attr() noexcept {} - -CCL_API datatype_attr& datatype_attr::operator=(const datatype_attr& src) { - this->get_impl() = src.get_impl(); - return *this; -} - -CCL_API datatype_attr& datatype_attr::operator=(datatype_attr&& src) { - if (src.get_impl() != this->get_impl()) { - src.get_impl().swap(this->get_impl()); - src.get_impl().reset(); - } - return *this; -} - -API_FORCE_SETTER_INSTANTIATION(datatype_attr, - datatype_attr_id::size, - int, - ccl_api_type_attr_traits); -API_FORCE_SETTER_INSTANTIATION(datatype_attr, - datatype_attr_id::size, - size_t, - ccl_api_type_attr_traits); -API_FORCE_GETTER_INSTANTIATION(datatype_attr, - datatype_attr_id::size, - size_t, - ccl_api_type_attr_traits); -API_FORCE_GETTER_INSTANTIATION(datatype_attr, - datatype_attr_id::version, - 
ccl::library_version, - ccl_api_type_attr_traits); - -#undef API_FORCE_SETTER_INSTANTIATION -#undef API_FORCE_GETTER_INSTANTIATION -#undef COMMA - -} // namespace ccl +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/datatype_attr_ids.hpp" +#include "oneapi/ccl/datatype_attr_ids_traits.hpp" +#include "oneapi/ccl/datatype_attr.hpp" + +// Core file with PIMPL implementation +#include "common/datatype/datatype_attr.hpp" +#include "datatype_attr_impl.hpp" + +namespace ccl { + +namespace v1 { + +#define API_FORCE_SETTER_INSTANTIATION(class_name, IN_attrId, IN_Value, OUT_Traits_Value) \ + template CCL_API IN_Value class_name::set(const IN_Value& v); + +#define API_FORCE_GETTER_INSTANTIATION(class_name, IN_attrId, OUT_Traits_Value) \ + template CCL_API const typename OUT_Traits_Value::return_type& \ + class_name::get() const; + +/** + * datatype_attr attributes definition + */ +CCL_API datatype_attr::datatype_attr(datatype_attr&& src) : base_t(std::move(src)) {} + +CCL_API datatype_attr::datatype_attr(const datatype_attr& src) : base_t(src) {} + +CCL_API datatype_attr::datatype_attr( + const typename detail::ccl_api_type_attr_traits::return_type& + version) + : base_t(impl_value_t(new impl_t(version))) {} + +CCL_API datatype_attr::~datatype_attr() noexcept {} + +CCL_API datatype_attr& datatype_attr::operator=(const datatype_attr& src) { + this->get_impl() = src.get_impl(); + return *this; +} + +CCL_API datatype_attr& datatype_attr::operator=(datatype_attr&& src) { + if (src.get_impl() != this->get_impl()) { + src.get_impl().swap(this->get_impl()); + src.get_impl().reset(); + } + return *this; +} + +API_FORCE_SETTER_INSTANTIATION(datatype_attr, + datatype_attr_id::size, + int, + detail::ccl_api_type_attr_traits); +API_FORCE_SETTER_INSTANTIATION(datatype_attr, + datatype_attr_id::size, + size_t, + detail::ccl_api_type_attr_traits); +API_FORCE_GETTER_INSTANTIATION(datatype_attr, + 
datatype_attr_id::size, + detail::ccl_api_type_attr_traits); +API_FORCE_GETTER_INSTANTIATION(datatype_attr, + datatype_attr_id::version, + detail::ccl_api_type_attr_traits); + +#undef API_FORCE_SETTER_INSTANTIATION +#undef API_FORCE_GETTER_INSTANTIATION + +} // namespace v1 + +} // namespace ccl diff --git a/src/ccl_app_api_event.cpp b/src/ccl_app_api_event.cpp index 143cbbdbb..c2ca83a55 100644 --- a/src/ccl_app_api_event.cpp +++ b/src/ccl_app_api_event.cpp @@ -13,70 +13,78 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "common/event/impls/event_impl.hpp" -#include "common/event/impls/empty_event.hpp" -#include "common/event/impls/native_event.hpp" - -namespace ccl { - -CCL_API event::event() noexcept : base_t(impl_value_t(new empty_event_impl())) {} -CCL_API event::event(event&& src) noexcept : base_t(std::move(src)) {} -CCL_API event::event(impl_value_t&& impl) noexcept : base_t(std::move(impl)) {} -CCL_API event::~event() noexcept {} - -CCL_API event& event::operator=(event&& src) noexcept { - if (this->get_impl() != src.get_impl()) { - this->get_impl() = std::move(src.get_impl()); - } - return *this; -} - -bool CCL_API event::operator==(const event& rhs) const noexcept { - return this->get_impl() == rhs.get_impl(); -} - -bool CCL_API event::operator!=(const event& rhs) const noexcept { - return this->get_impl() != rhs.get_impl(); -} - -CCL_API event::operator bool() { - return this->test(); -} - -void CCL_API event::wait() { - get_impl()->wait(); -} - -bool CCL_API event::test() { - return get_impl()->test(); -} - -bool CCL_API event::cancel() { - return get_impl()->cancel(); -} - -CCL_API event::native_t& event::get_native() { - return const_cast(get_impl()->get_native()); -} - -CCL_API const event::native_t& event::get_native() const { - return get_impl()->get_native(); -} - -event CCL_API 
event::create_from_native(native_t& native_event) { - library_version version; - version.major = CCL_MAJOR_VERSION; - version.minor = CCL_MINOR_VERSION; - version.update = CCL_UPDATE_VERSION; - version.product_status = CCL_PRODUCT_STATUS; - version.build_date = CCL_PRODUCT_BUILD_DATE; - version.full = CCL_PRODUCT_FULL; - - return impl_value_t( - new native_event_impl(native_event, version) - ); -} - -} // namespace ccl +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "common/event/impls/event_impl.hpp" +#include "common/event/impls/empty_event.hpp" +#include "common/event/impls/native_event.hpp" +#include "common/utils/version.hpp" + +namespace ccl { + +namespace v1 { + +CCL_API event::event() noexcept : base_t(impl_value_t(new empty_event_impl())) {} +CCL_API event::event(event&& src) noexcept : base_t(std::move(src)) {} +CCL_API event::event(impl_value_t&& impl) noexcept : base_t(std::move(impl)) {} +CCL_API event::~event() noexcept {} + +CCL_API event& event::operator=(event&& src) noexcept { + if (this->get_impl() != src.get_impl()) { + this->get_impl() = std::move(src.get_impl()); + } + return *this; +} + +bool CCL_API event::operator==(const event& rhs) const noexcept { + return this->get_impl() == rhs.get_impl(); +} + +bool CCL_API event::operator!=(const event& rhs) const noexcept { + return this->get_impl() != rhs.get_impl(); +} + +CCL_API event::operator bool() { + return this->test(); +} + +void CCL_API event::wait() { + get_impl()->wait(); +} + +bool CCL_API event::test() { + return get_impl()->test(); +} + +bool CCL_API event::cancel() { + return get_impl()->cancel(); +} + +CCL_API event::native_t& event::get_native() { + return const_cast(get_impl()->get_native()); +} + +CCL_API const event::native_t& event::get_native() const { + return get_impl()->get_native(); +} + +event CCL_API event::create_from_native(native_t& native_event) { + auto version = utils::get_library_version(); + + auto ev = std::unique_ptr(new 
ccl_event(native_event, version)); + + return impl_value_t(new native_event_impl(std::move(ev))); +} + +event CCL_API event::create_from_native(native_handle_t native_event_handle, context_t context) { + auto version = utils::get_library_version(); + + auto ev = std::unique_ptr(new ccl_event(native_event_handle, context, version)); + ev->build_from_params(); + + return impl_value_t(new native_event_impl(std::move(ev))); +} + +} // namespace v1 + +} // namespace ccl diff --git a/src/ccl_app_api_init_attr.cpp b/src/ccl_app_api_init_attr.cpp new file mode 100644 index 000000000..2e2b72d8b --- /dev/null +++ b/src/ccl_app_api_init_attr.cpp @@ -0,0 +1,71 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/init_attr_ids.hpp" +#include "oneapi/ccl/init_attr_ids_traits.hpp" +#include "oneapi/ccl/init_attr.hpp" + +// Core file with PIMPL implementation +#include "init_attr_impl.hpp" + +namespace ccl { + +namespace v1 { + +#define API_FORCE_SETTER_INSTANTIATION(class_name, IN_attrId, IN_Value, OUT_Traits_Value) \ + template CCL_API IN_Value class_name::set(const IN_Value& v); + +#define API_FORCE_GETTER_INSTANTIATION(class_name, IN_attrId, OUT_Traits_Value) \ + template CCL_API const typename OUT_Traits_Value::return_type& \ + class_name::get() const; + +/** + * init_attr attributes definition + */ +CCL_API init_attr::init_attr(init_attr&& src) : base_t(std::move(src)) {} + +CCL_API init_attr::init_attr(const init_attr& src) : base_t(src) {} + +CCL_API init_attr::init_attr( + const typename detail::ccl_api_type_attr_traits::return_type& version) + : base_t(impl_value_t(new impl_t(version))) {} + +CCL_API init_attr::~init_attr() noexcept {} + +CCL_API init_attr& init_attr::operator=(const init_attr& src) { + this->get_impl() = src.get_impl(); + return *this; +} + +CCL_API init_attr& init_attr::operator=(init_attr&& src) { + if (src.get_impl() != this->get_impl()) { + src.get_impl().swap(this->get_impl()); + src.get_impl().reset(); + } + return *this; +} + +API_FORCE_GETTER_INSTANTIATION(init_attr, init_attr_id::version, detail::ccl_api_type_attr_traits); + +#undef API_FORCE_SETTER_INSTANTIATION +#undef API_FORCE_GETTER_INSTANTIATION + +} // namespace v1 + +} // namespace ccl diff --git a/src/ccl_app_api_kvs_attr.cpp b/src/ccl_app_api_kvs_attr.cpp new file mode 100644 index 000000000..a21bb5381 --- /dev/null +++ b/src/ccl_app_api_kvs_attr.cpp @@ -0,0 +1,77 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/kvs_attr_ids.hpp" +#include "oneapi/ccl/kvs_attr_ids_traits.hpp" +#include "oneapi/ccl/kvs_attr.hpp" + +// Core file with PIMPL implementation +#include "atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/kvs_common_attr.hpp" +#include "kvs_attr_impl.hpp" + +namespace ccl { + +namespace v1 { + +#define API_FORCE_INSTANTIATION(class_name, IN_attrId, IN_Value, OUT_Traits_Value) \ + template CCL_API IN_Value class_name::set(const IN_Value& v); \ +\ + template CCL_API const typename OUT_Traits_Value::type& \ + class_name::get() const; \ +\ + template CCL_API bool class_name::is_valid() const noexcept; + +/** + * kvs_attr attributes definition + */ +CCL_API kvs_attr::kvs_attr(ccl_empty_attr) + : base_t(impl_value_t(new impl_t(ccl_empty_attr::version))) {} +CCL_API kvs_attr::kvs_attr(kvs_attr&& src) : base_t(std::move(src)) {} + +CCL_API kvs_attr::kvs_attr(const kvs_attr& src) : base_t(src) {} + +CCL_API kvs_attr::kvs_attr( + const typename detail::ccl_api_type_attr_traits::return_type& + version) + : base_t(impl_value_t(new impl_t(version))) {} + +CCL_API kvs_attr::~kvs_attr() noexcept {} + +CCL_API kvs_attr& kvs_attr::operator=(const kvs_attr& src) { + this->get_impl() = src.get_impl(); + return *this; +} + +CCL_API kvs_attr& kvs_attr::operator=(kvs_attr&& src) { + if (src.get_impl() != this->get_impl()) { + src.get_impl().swap(this->get_impl()); + src.get_impl().reset(); + } + return *this; +} + +API_FORCE_INSTANTIATION(kvs_attr, + 
kvs_attr_id::version, + ccl::library_version, + detail::ccl_api_type_attr_traits) + +#undef API_FORCE_INSTANTIATION + +} // namespace v1 + +} // namespace ccl diff --git a/src/ccl_cpp_api.cpp b/src/ccl_cpp_api.cpp deleted file mode 100644 index 813d69471..000000000 --- a/src/ccl_cpp_api.cpp +++ /dev/null @@ -1,229 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#if 0 -#include "oneapi/ccl.hpp" - -#include "coll/coll_attributes.hpp" - -#include "common/comm/comm_split_common_attr.hpp" -#include "comm_split_attr_impl.hpp" - -#include "common/comm/l0/comm_context_storage.hpp" - -#include "common/event/event_internal/event_internal_impl.hpp" -#include "stream_impl.hpp" - -#include "common/global/global.hpp" -#include "common/comm/comm.hpp" - -#include "common/comm/l0/comm_context.hpp" -#include "oneapi/ccl/ccl_communicator.hpp" - -#include "common/global/global.hpp" -#include "exec/exec.hpp" - -#include "common/comm/comm_interface.hpp" - -#include "oneapi/ccl/native_device_api/export_api.hpp" - -#ifdef CCL_ENABLE_SYCL -#include -#endif - -#define CCL_CHECK_AND_THROW(result, diagnostic) \ - do { \ - if (result != ccl_status_success) { \ - throw ccl::exception(diagnostic); \ - } \ - } while (0); - - -namespace ccl -{ - -CCL_API ccl::environment::environment() -{ - static auto result = global_data::get().init(); - CCL_CHECK_AND_THROW(result, "failed to initialize CCL"); -} - -CCL_API ccl::environment::~environment() -{} - -CCL_API 
ccl::environment& ccl::environment::instance() -{ - static ccl::environment env; - return env; -} - -void CCL_API ccl::environment::set_resize_fn(ccl_resize_fn_t callback) -{ - ccl_status_t result = ccl_set_resize_fn(callback); - CCL_CHECK_AND_THROW(result, "failed to set resize callback"); - return; -} - -ccl::library_version CCL_API ccl::environment::get_version() const -{ - ccl::library_version ret; - ccl_status_t result = ccl_get_version(&ret); - CCL_CHECK_AND_THROW(result, "failed to get version"); - return ret; -} -/* -static ccl::stream& get_empty_stream() -{ - static ccl::stream_t empty_stream = ccl::environment::instance().create_stream(); - return empty_stream; -} -*/ - -/** - * Factory methods - */ -// KVS -kvs_t CCL_API environment::create_main_kvs() const -{ - return std::shared_ptr(new kvs); -} - -kvs_t CCL_API environment::create_kvs(const kvs::addr_t& addr) const -{ - return std::shared_ptr(new kvs(addr)); -} - -//Communicator -communicator CCL_API environment::create_communicator() const -{ - return communicator::create_communicator(); -} - -communicator CCL_API environment::create_communicator(const size_t size, - shared_ptr_class kvs) const -{ - return communicator::create_communicator(size, kvs); -} - -communicator CCL_API environment::create_communicator(const size_t size, - const size_t rank, - shared_ptr_class kvs) const -{ - return communicator::create_communicator(size, rank, kvs); -} - -//Device communicator -#ifdef MULTI_GPU_SUPPORT - -template -comm_split_attr environment::create_comm_split_attr(attr_value_pair_t&&...avps) const -{ - return comm_split_attr::create_comm_split_attr(std::forward(avps)...); -} - -template -vector_class CCL_API environment::create_communicators( - const size_t devices_size, - const vector_class& local_devices, - ContextType& context, - shared_ptr_class kvs) const -{ - return communicator::create_communicators(devices_size, local_devices, context, kvs); -} - -template -vector_class CCL_API 
environment::create_communicators( - const size_t cluster_devices_size, /*global devics count*/ - const vector_class>& local_rank_device_map, - ContextType& context, - shared_ptr_class kvs) -{ - return communicator::create_communicators(cluster_devices_size, local_rank_device_map, context, kvs); -} - - -template -vector_class CCL_API environment::create_communicators( - const size_t cluster_devices_size, /*global devics count*/ - const map_class& local_rank_device_map, - ContextType& context, - shared_ptr_class kvs) -{ - return communicator::create_communicators(cluster_devices_size, local_rank_device_map, context, kvs); -} - - -//Stream -template -stream CCL_API environment::create_stream(native_stream_type& native_stream) -{ - return stream::create_stream(native_stream); -} - -template -stream CCL_API environment::create_stream(native_stream_type& native_stream, native_context_type& native_ctx) -{ - return stream::create_stream(native_stream, native_ctx); -} - -template -stream CCL_API environment::create_stream_from_attr(typename unified_device_type::ccl_native_t device, attr_value_pair_t&&...avps) -{ - return stream::create_stream_from_attr(device, std::forward(avps)...); -} - -template -stream CCL_API environment::create_stream_from_attr(typename unified_device_type::ccl_native_t device, - typename unified_device_context_type::ccl_native_t context, - attr_value_pair_t&&...avps) -{ - return stream::create_stream_from_attr(device, context, std::forward(avps)...); -} - - -//Event -template -event CCL_API environment::create_event(event_type& native_event) -{ - return event::create_event(native_event); -} - -template -event CCL_API environment::create_event_from_attr(event_type& native_event_handle, - typename unified_device_context_type::ccl_native_t context, - attr_value_pair_t&&...avps) -{ - return event::create_event_from_attr(native_event_handle, context, std::forward(avps)...); -} -/* -#define STREAM_CREATOR_INSTANTIATION(type) \ -template ccl::stream_t 
CCL_API ccl::environment::create_stream(type& stream); - -#ifdef CCL_ENABLE_SYCL -STREAM_CREATOR_INSTANTIATION(cl::sycl::queue) -#endif -*/ -#endif //MULTI_GPU_SUPPORT -} -#include "types_generator_defines.hpp" -#include "oneapi/ccl/ccl_cpp_api_explicit_in.hpp" -#endif //0 diff --git a/src/ccl_cpp_api_explicit_in.hpp b/src/ccl_cpp_api_explicit_in.hpp deleted file mode 100644 index 9a36f6393..000000000 --- a/src/ccl_cpp_api_explicit_in.hpp +++ /dev/null @@ -1,88 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -#ifndef COMMA -#define COMMA , -#endif - -//TODO -#if 0 -/** - * Attributes - */ -HOST_ATTRIBUTE_INSTANTIATION(ccl_host_color, - typename ccl::comm_split_attr_id_traits::type); -HOST_ATTRIBUTE_INSTANTIATION(ccl_host_version, - typename ccl::comm_split_attr_id_traits::type); - -API_COLL_EXPLICIT_INSTANTIATION(char); -API_COLL_EXPLICIT_INSTANTIATION(int); -API_COLL_EXPLICIT_INSTANTIATION(int64_t); -API_COLL_EXPLICIT_INSTANTIATION(uint64_t); -API_COLL_EXPLICIT_INSTANTIATION(float); -API_COLL_EXPLICIT_INSTANTIATION(double); - -#ifdef CCL_ENABLE_SYCL - API_COLL_EXPLICIT_CLASS_INSTANTIATION(cl::sycl::buffer); - API_COLL_EXPLICIT_CLASS_INSTANTIATION(cl::sycl::buffer); - API_COLL_EXPLICIT_CLASS_INSTANTIATION(cl::sycl::buffer); - API_COLL_EXPLICIT_CLASS_INSTANTIATION(cl::sycl::buffer); - API_COLL_EXPLICIT_CLASS_INSTANTIATION(cl::sycl::buffer); - API_COLL_EXPLICIT_CLASS_INSTANTIATION(cl::sycl::buffer); -#endif //CCL_ENABLE_SYCL - -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(char, char); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(char, int); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(char, ccl::bf16); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(char, float); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(char, double); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(char, int64_t); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(char, uint64_t); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int, char); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int, int); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int, ccl::bf16); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int, float); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int, double); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int, int64_t); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int, uint64_t); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int64_t, char); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int64_t, int); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int64_t, ccl::bf16); 
-API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int64_t, float); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int64_t, double); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int64_t, int64_t); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(int64_t, uint64_t); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(uint64_t, char); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(uint64_t, int); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(uint64_t, ccl::bf16); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(uint64_t, float); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(uint64_t, double); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(uint64_t, int64_t); -API_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(uint64_t, uint64_t); - -#ifdef CCL_ENABLE_SYCL - API_SPARSE_ALLREDUCE_EXPLICIT_CLASS_INSTANTIATION(cl::sycl::buffer, - cl::sycl::buffer); - API_SPARSE_ALLREDUCE_EXPLICIT_CLASS_INSTANTIATION(cl::sycl::buffer, - cl::sycl::buffer); - - API_SPARSE_ALLREDUCE_EXPLICIT_CLASS_INSTANTIATION(cl::sycl::buffer, - cl::sycl::buffer); - API_SPARSE_ALLREDUCE_EXPLICIT_CLASS_INSTANTIATION(cl::sycl::buffer, - cl::sycl::buffer); -#endif //CCL_ENABLE_SYCL -#undef COMMA - -#endif //TODO diff --git a/src/ccl_cpp_communicator.cpp b/src/ccl_cpp_communicator.cpp index 5a1743776..7e608b59b 100644 --- a/src/ccl_cpp_communicator.cpp +++ b/src/ccl_cpp_communicator.cpp @@ -13,31 +13,39 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_aliases.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" -#include "oneapi/ccl/ccl_type_traits.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" +#include "oneapi/ccl/type_traits.hpp" +#include "oneapi/ccl/types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_coll_attr.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" +#include "oneapi/ccl/coll_attr.hpp" -#include "oneapi/ccl/ccl_comm_split_attr_ids.hpp" -#include "oneapi/ccl/ccl_comm_split_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_comm_split_attr.hpp" +#include "oneapi/ccl/comm_attr_ids.hpp" +#include "oneapi/ccl/comm_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_attr.hpp" -#include "common/event/event_internal/event_internal_attr_ids.hpp" -#include "common/event/event_internal/event_internal_attr_ids_traits.hpp" -#include "common/event/event_internal/event_internal.hpp" +#include "oneapi/ccl/comm_split_attr_ids.hpp" +#include "oneapi/ccl/comm_split_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_split_attr.hpp" -#include "oneapi/ccl/ccl_stream_attr_ids.hpp" -#include "oneapi/ccl/ccl_stream_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_stream.hpp" +#include "oneapi/ccl/stream_attr_ids.hpp" +#include "oneapi/ccl/stream_attr_ids_traits.hpp" +#include "oneapi/ccl/stream.hpp" -#include "oneapi/ccl/ccl_event.hpp" +#include "oneapi/ccl/device_attr_ids.hpp" +#include "oneapi/ccl/device_attr_ids_traits.hpp" +#include "oneapi/ccl/device.hpp" -#include "oneapi/ccl/ccl_communicator.hpp" +#include "oneapi/ccl/context_attr_ids.hpp" +#include "oneapi/ccl/context_attr_ids_traits.hpp" +#include "oneapi/ccl/context.hpp" + +#include "oneapi/ccl/event.hpp" + +#include "oneapi/ccl/communicator.hpp" #include "common/comm/l0/comm_context_storage.hpp" #include "common/global/global.hpp" @@ -50,10 +58,11 @@ 
namespace ccl { +namespace v1 { + CCL_API communicator::communicator(impl_value_t&& impl) : base_t(std::move(impl)) {} -CCL_API communicator::communicator(communicator&& src) - : base_t(std::move(src)) {} +CCL_API communicator::communicator(communicator&& src) : base_t(std::move(src)) {} CCL_API communicator& communicator::operator=(communicator&& src) { if (src.get_impl() != this->get_impl()) { @@ -65,11 +74,11 @@ CCL_API communicator& communicator::operator=(communicator&& src) { CCL_API communicator::~communicator() {} -CCL_API size_t communicator::rank() const { +CCL_API int communicator::rank() const { return get_impl()->rank(); } -CCL_API size_t communicator::size() const { +CCL_API int communicator::size() const { return get_impl()->size(); } @@ -82,32 +91,37 @@ CCL_API communicator communicator::split(const comm_split_attr& attr) { return communicator(get_impl()->split(attr)); } -CCL_API communicator::ccl_device_t communicator::get_device() { - return get_impl()->get_device(); +CCL_API device communicator::get_device() const { + return device::create_device(get_impl()->get_device()); } -CCL_API communicator::ccl_context_t communicator::get_context() { - return get_impl()->get_context(); +CCL_API context communicator::get_context() const { + return context::create_context(get_impl()->get_context()); } +} // namespace v1 + } // namespace ccl /****API force instantiations for factory methods******/ -API_DEVICE_COMM_CREATE_WO_RANK_EXPLICIT_INSTANTIATION(ccl::device, ccl::context) -API_DEVICE_COMM_CREATE_WITH_RANK_IN_VECTOR_EXPLICIT_INSTANTIATION(ccl::device, - ccl::context) -API_DEVICE_COMM_CREATE_WITH_RANK_IN_MAP_EXPLICIT_INSTANTIATION(ccl::device, ccl::context) - -API_DEVICE_COMM_CREATE_WO_RANK_EXPLICIT_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, typename ccl::unified_device_context_type::ccl_native_t) -API_DEVICE_COMM_CREATE_WITH_RANK_IN_VECTOR_EXPLICIT_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, typename 
ccl::unified_device_context_type::ccl_native_t) -API_DEVICE_COMM_CREATE_WITH_RANK_IN_MAP_EXPLICIT_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, typename ccl::unified_device_context_type::ccl_native_t) - -API_DEVICE_COMM_CREATE_WO_RANK_EXPLICIT_INSTANTIATION( - ccl::device_index_type, - typename ccl::unified_device_context_type::ccl_native_t) -API_DEVICE_COMM_CREATE_WITH_RANK_IN_VECTOR_EXPLICIT_INSTANTIATION( +API_COMM_CREATE_WO_RANK_EXPLICIT_INSTANTIATION(ccl::device, ccl::context) +API_COMM_CREATE_WITH_RANK_IN_VECTOR_EXPLICIT_INSTANTIATION(ccl::device, ccl::context) +API_COMM_CREATE_WITH_RANK_IN_MAP_EXPLICIT_INSTANTIATION(ccl::device, ccl::context) + +API_COMM_CREATE_WO_RANK_EXPLICIT_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, + typename ccl::unified_context_type::ccl_native_t) +API_COMM_CREATE_WITH_RANK_IN_VECTOR_EXPLICIT_INSTANTIATION( + typename ccl::unified_device_type::ccl_native_t, + typename ccl::unified_context_type::ccl_native_t) +API_COMM_CREATE_WITH_RANK_IN_MAP_EXPLICIT_INSTANTIATION( + typename ccl::unified_device_type::ccl_native_t, + typename ccl::unified_context_type::ccl_native_t) + +API_COMM_CREATE_WO_RANK_EXPLICIT_INSTANTIATION(ccl::device_index_type, + typename ccl::unified_context_type::ccl_native_t) +API_COMM_CREATE_WITH_RANK_IN_VECTOR_EXPLICIT_INSTANTIATION( ccl::device_index_type, - typename ccl::unified_device_context_type::ccl_native_t) -API_DEVICE_COMM_CREATE_WITH_RANK_IN_MAP_EXPLICIT_INSTANTIATION( + typename ccl::unified_context_type::ccl_native_t) +API_COMM_CREATE_WITH_RANK_IN_MAP_EXPLICIT_INSTANTIATION( ccl::device_index_type, - typename ccl::unified_device_context_type::ccl_native_t) + typename ccl::unified_context_type::ccl_native_t) diff --git a/src/ccl_cpp_context.cpp b/src/ccl_cpp_context.cpp index ff7e409d3..7db9a742b 100644 --- a/src/ccl_cpp_context.cpp +++ b/src/ccl_cpp_context.cpp @@ -13,11 +13,13 @@ See the License for the specific language governing permissions and limitations under the 
License. */ -#include "oneapi/ccl/ccl_types.hpp" +#include "oneapi/ccl/types.hpp" #include "context_impl.hpp" namespace ccl { +namespace v1 { + CCL_API context::context(context&& src) : base_t(std::move(src)) {} CCL_API context::context(const context& src) : base_t(src) {} @@ -40,25 +42,37 @@ CCL_API context& context::operator=(const context& src) { return *this; } +bool CCL_API context::operator==(const context& rhs) const noexcept { + return this->get_impl() == rhs.get_impl(); +} + +bool CCL_API context::operator!=(const context& rhs) const noexcept { + return this->get_impl() != rhs.get_impl(); +} + +bool CCL_API context::operator<(const context& rhs) const noexcept { + return this->get_impl() < rhs.get_impl(); +} + CCL_API void context::build_from_params() { get_impl()->build_from_params(); } -CCL_API context::native_t& context::get_native() -{ +CCL_API context::native_t& context::get_native() { return const_cast(static_cast(this)->get_native()); } -CCL_API const context::native_t& context::get_native() const -{ +CCL_API const context::native_t& context::get_native() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } -} // namespace ccl +} // namespace v1 + +} // namespace ccl -API_DEVICE_CONTEXT_CREATION_FORCE_INSTANTIATION(typename ccl::unified_device_context_type::ccl_native_t) +API_CONTEXT_CREATION_FORCE_INSTANTIATION(typename ccl::unified_context_type::ccl_native_t) -API_DEVICE_CONTEXT_FORCE_INSTANTIATION(ccl::context_attr_id::version, ccl::library_version); -API_DEVICE_CONTEXT_FORCE_INSTANTIATION_GET(ccl::context_attr_id::cl_backend); -API_DEVICE_CONTEXT_FORCE_INSTANTIATION_GET(ccl::context_attr_id::native_handle); +API_CONTEXT_FORCE_INSTANTIATION(ccl::context_attr_id::version, ccl::library_version); +API_CONTEXT_FORCE_INSTANTIATION_GET(ccl::context_attr_id::cl_backend); +API_CONTEXT_FORCE_INSTANTIATION_GET(ccl::context_attr_id::native_handle); diff --git 
a/src/ccl_cpp_device.cpp b/src/ccl_cpp_device.cpp index 68e9c4003..826fcabcc 100644 --- a/src/ccl_cpp_device.cpp +++ b/src/ccl_cpp_device.cpp @@ -13,11 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "oneapi/ccl/ccl_types.hpp" +#include "oneapi/ccl/types.hpp" #include "device_impl.hpp" namespace ccl { +namespace v1 { + CCL_API device::device(device&& src) : base_t(std::move(src)) {} CCL_API device::device(const device& src) : base_t(src) {} @@ -41,20 +43,33 @@ CCL_API device& device::operator=(const device& src) { return *this; } +bool CCL_API device::operator==(const device& rhs) const noexcept { + return this->get_impl() == rhs.get_impl(); +} + +bool CCL_API device::operator!=(const device& rhs) const noexcept { + return this->get_impl() != rhs.get_impl(); +} + +bool CCL_API device::operator<(const device& rhs) const noexcept { + return this->get_impl() < rhs.get_impl(); +} + CCL_API void device::build_from_params() { get_impl()->build_from_params(); } -CCL_API device::native_t& device::get_native() -{ +CCL_API device::native_t& device::get_native() { return const_cast(static_cast(this)->get_native()); } -CCL_API const device::native_t& device::get_native() const -{ +CCL_API const device::native_t& device::get_native() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } + +} // namespace v1 + } // namespace ccl API_DEVICE_CREATION_FORCE_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t) diff --git a/src/ccl_cpp_environment.cpp b/src/ccl_cpp_environment.cpp index 1fc00ccd0..a31010b4a 100644 --- a/src/ccl_cpp_environment.cpp +++ b/src/ccl_cpp_environment.cpp @@ -16,85 +16,65 @@ #include "environment_impl.hpp" #include "common/global/global.hpp" #include "exec/exec.hpp" +#include "common/utils/version.hpp" #if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) #include "common/comm/l0/comm_context.hpp" 
#include "common/comm/comm_interface.hpp" #endif //#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) -//#include "ccl.h" //TODO datatypes - #include #include "common/comm/single_device_communicator/single_device_communicator.hpp" namespace ccl { -CCL_API ccl::environment::environment() { + +namespace detail { + +CCL_API environment::environment() { static auto result = global_data::get().init(); CCL_CHECK_AND_THROW(result, "failed to initialize CCL"); } -CCL_API ccl::environment::~environment() {} +CCL_API environment::~environment() {} -CCL_API ccl::environment& ccl::environment::instance() { - static ccl::environment env; +CCL_API environment& environment::instance() { + static environment env; return env; } -// void CCL_API ccl::environment::set_resize_fn(ccl_resize_fn_t callback) -// { -// ccl_status_t result = ccl_set_resize_fn(callback); -// CCL_CHECK_AND_THROW(result, "failed to set resize callback"); -// return; -// } - -ccl::library_version CCL_API ccl::environment::get_library_version() const { - ccl::library_version ret; +ccl::library_version CCL_API environment::get_library_version() { + return utils::get_library_version(); +} - ret.major = CCL_MAJOR_VERSION; - ret.minor = CCL_MINOR_VERSION; - ret.update = CCL_UPDATE_VERSION; - ret.product_status = CCL_PRODUCT_STATUS; - ret.build_date = CCL_PRODUCT_BUILD_DATE; - ret.full = CCL_PRODUCT_FULL; +/******************** KVS ********************/ - return ret; -} -/* -static ccl::stream& get_empty_stream() -{ - static ccl::stream_t empty_stream = ccl::environment::instance().create_stream(); - return empty_stream; +shared_ptr_class environment::create_main_kvs(const kvs_attr& attr) const { + return std::shared_ptr(new kvs(attr)); } -*/ -/** - * Factory methods - */ -// KVS -shared_ptr_class CCL_API environment::create_main_kvs() const { - return std::shared_ptr(new kvs); +shared_ptr_class environment::create_kvs(const kvs::address_type& addr, + const kvs_attr& attr) const { + return std::shared_ptr(new 
kvs(addr, attr)); } -shared_ptr_class CCL_API environment::create_kvs(const kvs::address_type& addr) const { - return std::shared_ptr(new kvs(addr)); -} +/******************** DEVICE ********************/ -// device -device CCL_API environment::create_device(empty_t empty) const -{ +device environment::create_device(empty_t empty) const { static typename ccl::unified_device_type::ccl_native_t default_native_device; return device::create_device(default_native_device); } -// context -context CCL_API environment::create_context(empty_t empty) const -{ - static typename ccl::unified_device_context_type::ccl_native_t default_native_context; +/******************** CONTEXT ********************/ + +context environment::create_context(empty_t empty) const { + static typename ccl::unified_context_type::ccl_native_t default_native_context; return context::create_context(default_native_context); } -ccl::datatype CCL_API environment::register_datatype(const ccl::datatype_attr& attr) { +/******************** DATATYPE ********************/ + +ccl::datatype environment::register_datatype(const datatype_attr& attr) { while (unlikely(ccl::global_data::get().executor->is_locked)) { std::this_thread::yield(); } @@ -104,7 +84,7 @@ ccl::datatype CCL_API environment::register_datatype(const ccl::datatype_attr& a return ccl::global_data::get().dtypes->create(attr); } -void CCL_API environment::deregister_datatype(ccl::datatype dtype) { +void environment::deregister_datatype(ccl::datatype dtype) { while (unlikely(ccl::global_data::get().executor->is_locked)) { std::this_thread::yield(); } @@ -114,7 +94,7 @@ void CCL_API environment::deregister_datatype(ccl::datatype dtype) { ccl::global_data::get().dtypes->free(dtype); } -size_t CCL_API environment::get_datatype_size(ccl::datatype dtype) const { +size_t environment::get_datatype_size(ccl::datatype dtype) const { while (unlikely(ccl::global_data::get().executor->is_locked)) { std::this_thread::yield(); } @@ -122,109 +102,78 @@ size_t 
CCL_API environment::get_datatype_size(ccl::datatype dtype) const { return ccl::global_data::get().dtypes->get(dtype).size(); } -} // namespace ccl +/******************** STREAM ********************/ + +stream CCL_API environment::create_stream(typename unified_device_type::ccl_native_t device) { + auto version = utils::get_library_version(); + return stream{ stream_provider_dispatcher::create(device, version) }; +} + +stream CCL_API environment::create_stream(typename unified_device_type::ccl_native_t device, + typename unified_context_type::ccl_native_t context) { + auto version = utils::get_library_version(); + return stream{ stream_provider_dispatcher::create(device, context, version) }; +} + +/******************** COMMUNICATOR ********************/ #ifdef CCL_ENABLE_SYCL -ccl::communicator CCL_API ccl::environment::create_single_device_communicator( - const size_t comm_size, - const size_t rank, +communicator environment::create_single_device_communicator( + const int comm_size, + const int rank, const cl::sycl::device& device, const cl::sycl::context& context, - ccl::shared_ptr_class kvs) const { + ccl::shared_ptr_class kvs) const { LOG_TRACE("Create single device communicator from SYCL device"); std::shared_ptr kvs_wrapper(new users_kvs(kvs)); std::shared_ptr atl = std::shared_ptr(new atl_wrapper(comm_size, { rank }, kvs_wrapper)); - ccl::comm_split_attr attr = create_comm_split_attr( - ccl::attr_val(ccl::group_split_type::undetermined)); - ccl::communicator_interface_ptr impl = - ccl::communicator_interface::create_communicator_impl(device, context, rank, comm_size, attr, atl); + comm_split_attr attr = create_comm_split_attr(attr_val( + split_group::cluster /*group_split_type::undetermined*/)); + ccl::communicator_interface_ptr impl = ccl::communicator_interface::create_communicator_impl( + device, context, rank, comm_size, attr, atl); //TODO use gpu_comm_attr to automatically visit() auto single_dev_comm = std::dynamic_pointer_cast(impl); 
//single_dev_comm->set_context(context); - return ccl::communicator(std::move(impl)); + return communicator(std::move(impl)); } - #endif -//Communicator -ccl::communicator CCL_API ccl::environment::create_communicator() const { - return ccl::communicator::create_communicator(); +communicator environment::create_communicator(const comm_attr& attr) const { + return communicator::create_communicator(attr); } -ccl::communicator CCL_API ccl::environment::create_communicator(const size_t size, - ccl::shared_ptr_class kvs) const { - return ccl::communicator::create_communicator(size, kvs); +communicator environment::create_communicator(const size_t size, + ccl::shared_ptr_class kvs, + const comm_attr& attr) const { + return communicator::create_communicator(size, kvs, attr); } -ccl::communicator CCL_API ccl::environment::create_communicator(const size_t size, - const size_t rank, - ccl::shared_ptr_class kvs) const { - return ccl::communicator::create_communicator(size, rank, kvs); +communicator environment::create_communicator(const size_t size, + const int rank, + ccl::shared_ptr_class kvs, + const comm_attr& attr) const { + return communicator::create_communicator(size, rank, kvs, attr); } -/***************************TypeGenerations*********************************************************/ -namespace ccl { -template <> -stream CCL_API environment::create_postponed_api_type< - stream, - typename unified_device_type::ccl_native_t, - typename unified_device_context_type::ccl_native_t>( - typename unified_device_type::ccl_native_t device, - typename unified_device_context_type::ccl_native_t context) const { - library_version ret{}; - ret.major = CCL_MAJOR_VERSION; - ret.minor = CCL_MINOR_VERSION; - ret.update = CCL_UPDATE_VERSION; - ret.product_status = CCL_PRODUCT_STATUS; - ret.build_date = CCL_PRODUCT_BUILD_DATE; - ret.full = CCL_PRODUCT_FULL; - - return stream{ stream_provider_dispatcher::create(device, context, ret) }; -} -template <> -stream CCL_API 
-environment::create_postponed_api_type( - typename unified_device_type::ccl_native_t device) const { - library_version ret{}; - ret.major = CCL_MAJOR_VERSION; - ret.minor = CCL_MINOR_VERSION; - ret.update = CCL_UPDATE_VERSION; - ret.product_status = CCL_PRODUCT_STATUS; - ret.build_date = CCL_PRODUCT_BUILD_DATE; - ret.full = CCL_PRODUCT_FULL; - - return stream{ stream_provider_dispatcher::create(device, ret) }; -} -} -CREATE_OP_ATTR_INSTANTIATION(ccl::allgatherv_attr) -CREATE_OP_ATTR_INSTANTIATION(ccl::allreduce_attr) -CREATE_OP_ATTR_INSTANTIATION(ccl::alltoall_attr) -CREATE_OP_ATTR_INSTANTIATION(ccl::alltoallv_attr) -CREATE_OP_ATTR_INSTANTIATION(ccl::broadcast_attr) -CREATE_OP_ATTR_INSTANTIATION(ccl::reduce_attr) -CREATE_OP_ATTR_INSTANTIATION(ccl::reduce_scatter_attr) -CREATE_OP_ATTR_INSTANTIATION(ccl::sparse_allreduce_attr) - -CREATE_OP_ATTR_INSTANTIATION(ccl::comm_split_attr) - -CREATE_OP_ATTR_INSTANTIATION(ccl::datatype_attr) +} // namespace detail + +} // namespace ccl + +/******************** TypeGenerations ********************/ CREATE_DEV_COMM_INSTANTIATION(ccl::device, ccl::context) -CREATE_DEV_COMM_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, typename ccl::unified_device_context_type::ccl_native_t) -CREATE_DEV_COMM_INSTANTIATION(ccl::device_index_type, typename ccl::unified_device_context_type::ccl_native_t) +CREATE_DEV_COMM_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, + typename ccl::unified_context_type::ccl_native_t) +CREATE_DEV_COMM_INSTANTIATION(ccl::device_index_type, + typename ccl::unified_context_type::ccl_native_t) CREATE_STREAM_INSTANTIATION(typename ccl::unified_stream_type::ccl_native_t) -CREATE_STREAM_EXT_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, typename ccl::unified_device_context_type::ccl_native_t) +CREATE_STREAM_EXT_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, + typename ccl::unified_context_type::ccl_native_t) -CREATE_CONTEXT_INSTANTIATION(typename 
ccl::unified_device_context_type::ccl_native_t) +CREATE_CONTEXT_INSTANTIATION(typename ccl::unified_context_type::ccl_native_t) CREATE_DEVICE_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t) - -/* -CREATE_EVENT_INSTANTIATION(cl::sycl::event) -CREATE_EVENT_EXT_INSTANTIATION(cl_event) -*/ diff --git a/src/ccl_cpp_gpu_api.cpp b/src/ccl_cpp_gpu_api.cpp deleted file mode 100644 index ffc7f2829..000000000 --- a/src/ccl_cpp_gpu_api.cpp +++ /dev/null @@ -1,256 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -//TODO -#if 0 -#include - -#include "oneapi/ccl.hpp" -#include "oneapi/ccl/ccl_type_traits.hpp" -#include "common/global/global.hpp" -#include "exec/exec.hpp" - -#include "common/comm/comm_interface.hpp" -#include "common/comm/host_communicator/host_communicator.hpp" -#include "common/comm/l0/gpu_comm_attr.hpp" -#include "common/comm/l0/device_community.hpp" - -#include "oneapi/ccl/native_device_api/export_api.hpp" -#include "oneapi/ccl/native_device_api/compiler_ccl_wrappers_dispatcher.hpp" - -#ifdef CCL_ENABLE_SYCL -#include -#endif - -std::ostream& operator<<(std::ostream& out, const ccl::device_index_type& index) -{ - out << ccl::to_string(index); - return out; -} - -namespace ccl -{ - - -/* GPU communicator attributes - */ -CCL_API ccl::ccl_device_attr::ccl_device_attr(const ccl::ccl_comm_split_attr& src) : - base_t(src), - pimpl(new ccl::device_attr_impl()) -{ -} - -CCL_API ccl::ccl_device_attr::~ccl_device_attr() noexcept -{ -} - -template -CCL_API Value ccl::ccl_device_attr::set_value(Value&& v) -{ - return pimpl->set_attribute_value(std::forward(v)); -} - -template -CCL_API const typename ccl::ccl_device_attributes_traits::type& ccl::ccl_device_attr::get_value() const -{ - return pimpl->get_attribute_value( - std::integral_constant {}); -} - -/* Global Environment*/ -template -CCL_API ccl::stream_t ccl::environment::create_stream(stream_native_type& s) -{ - return ccl::stream_t(new ccl::stream(stream_provider_dispatcher::create(s))); -} - -CCL_API ccl::comm_group_t ccl::environment::create_comm_group(size_t current_device_group_size, size_t process_device_group_size, - ccl::shared_communicator_t parent_comm /* = ccl::shared_communicator_t()*/) -{ - if (!parent_comm) - { - //use global communicator by default - ccl::shared_communicator_t(ccl::environment::instance().create_communicator()).swap(parent_comm); - } - - ccl::comm_group_t group; - { - // register group slot in global context table, based on communicator id - auto host_comm_impl = 
std::dynamic_pointer_cast(parent_comm->pimpl); - if (!host_comm_impl) - { - throw ccl::exception(std::string(__FUNCTION__) + " - failed, invalid host communicator type"); - } - - group_context::group_unique_key unique_id = - host_comm_impl->get_host_attr()->get_value(); - - std::unique_lock lock(global_ctx.mutex); - auto ctx_it = global_ctx.communicator_group_map.find(unique_id); - if(ctx_it == global_ctx.communicator_group_map.end()) - { - group.reset(new ccl::comm_group(parent_comm, - current_device_group_size, - process_device_group_size)); - global_ctx.communicator_group_map.insert({ - unique_id, - group - }); - } - else - { - group = ctx_it->second; - } - } - - // sync existing group: blocking operation - wait for all groups - group->pimpl->sync_group_size(current_device_group_size); - return group; -} - -CCL_API ccl::comm_group::comm_group(ccl::shared_communicator_t parent_comm, - size_t current_device_group_size, size_t process_device_group_size): - pimpl(new ccl::gpu_comm_attr(parent_comm, current_device_group_size, process_device_group_size)) -{ -}; - -/** - * Create communicator API: - */ -CCL_API ccl::comm_split_attr ccl::comm_group::create_comm_split_attr() -{ - // TODO - const auto& host_comm = pimpl->get_host_communicator(); - return ccl::comm_split_attr{new ccl::ccl_device_attr(*(host_comm->get_comm_split_attr()))}; -} -/* - * Single device communicator creation - */ -template ::type>::value, - int>::type> -CCL_API ccl::communicator_t ccl::comm_group::create_communicator(const DeviceType& device, - ccl::comm_split_attr attr/* = comm_device_attr_t()*/) -{ - LOG_TRACE("Create communicator from device"); - ccl::communicator_interface_ptr impl = - ccl::communicator_interface::create_communicator_impl(device, - pimpl->thread_id, - pimpl->ccl_communicator->rank(), - attr, - pimpl->ccl_communicator->comm_impl.atl); - // registering device in group - is non blocking operation, until it is not the last device - pimpl->sync_register_communicator(impl); - 
return ccl::communicator_t(new ccl::communicator(impl)); -} - -template ::type>::value, - int>::type> -CCL_API ccl::communicator_t ccl::comm_group::create_communicator(DeviceType device_id, - ccl::comm_split_attr attr/* = nullptr*/) -{ - LOG_TRACE("Create communicator from id: ", device_id); - - ccl::communicator_interface_ptr impl = ccl::communicator_interface::create_communicator_impl(device_id, - pimpl->thread_id, - pimpl->ccl_communicator->rank(), - attr); - // registering device in group - is non blocking operation, until it is not the last device - pimpl->sync_register_communicator(impl); - return ccl::communicator_t(new ccl::communicator(impl)); -} - -/** - * Multiple device communicators creation vectorized API implementation - */ -template -CCL_API std::vector ccl::comm_group::create_communicators(InputIt first, InputIt last, - ccl::comm_split_attr attr/* = nullptr*/) -{ - - using iterator_value_type = typename std::iterator_traits::value_type; -/* - using expected_value_type = typename unified_device_type::device_t; - static_assert(std::is_same::value, - "Not valid InputIt in create_communicators"); -*/ - size_t indices_count = std::distance(first, last); - LOG_TRACE("Create device communicators from index iterators type, count: ", indices_count); - - std::vector comms; - comms.reserve(indices_count); - std::transform(first, last, std::back_inserter(comms), [this, attr](const iterator_value_type& device_id) - { - return create_communicator(device_id, attr); - }); - return comms; -} - -template class Container, class Type> -CCL_API std::vector ccl::comm_group::create_communicators(const Container& device_ids, - ccl::comm_split_attr attr/* = nullptr*/) -{ - //static_assert(std::is_same::value, "Invalid Type in create_communicators"); - LOG_TRACE("Create device communicators from index type, count: ", device_ids.size(), - ". 
Redirect to iterators version"); - return create_communicators(device_ids.begin(), device_ids.end(), attr); -} - -CCL_API ccl::comm_group::device_context_native_const_reference_t ccl::comm_group::get_context() const -{ - //TODO use PIMPL as context provider - static unified_device_context_type context; - return context.get(); -} - - -/***********************************************************************/ -#define DEVICE_ATTRIBUTE_INSTANTIATION(ATTR_ID, VALUE_TYPE) \ - template VALUE_TYPE CCL_API ccl::ccl_device_attr::set_value(VALUE_TYPE && \ - v); \ - template CCL_API const VALUE_TYPE& ccl::ccl_device_attr::get_value() const; - -#define STREAM_CREATOR_INSTANTIATION(type) \ - template ccl::stream_t CCL_API ccl::environment::create_stream(type& stream); - -#define COMM_CREATOR_INDEXED_INSTANTIATION_CONTAINER(type) \ - template std::vector CCL_API ccl::comm_group::create_communicators( \ - const type& device_ids, ccl::comm_split_attr attr); - -// device attribute instantiations -DEVICE_ATTRIBUTE_INSTANTIATION(ccl_device_preferred_topology_class, - typename ccl::ccl_device_attributes_traits::type); -DEVICE_ATTRIBUTE_INSTANTIATION(ccl_device_preferred_group, - typename ccl::ccl_device_attributes_traits::type); - - -// stream instantiations -STREAM_CREATOR_INSTANTIATION(ze_command_queue_handle_t) -#ifdef CCL_ENABLE_SYCL - STREAM_CREATOR_INSTANTIATION(cl::sycl::queue) -#endif - -// container-based method force-instantiation will trigger ALL other methods instantiations -COMM_CREATOR_INDEXED_INSTANTIATION_CONTAINER(std::vector); -COMM_CREATOR_INDEXED_INSTANTIATION_CONTAINER(std::list); -COMM_CREATOR_INDEXED_INSTANTIATION_CONTAINER(ccl::device_indices_t); -#ifdef CCL_ENABLE_SYCL - COMM_CREATOR_INDEXED_INSTANTIATION_CONTAINER(cl::sycl::vector_class); -#endif - -#endif //TODO diff --git a/src/ccl_cpp_kvs.cpp b/src/ccl_cpp_kvs.cpp index d84d7a3ec..f18b130b8 100644 --- a/src/ccl_cpp_kvs.cpp +++ b/src/ccl_cpp_kvs.cpp @@ -13,4 +13,14 @@ See the License for the specific 
language governing permissions and limitations under the License. */ +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" + +#include "oneapi/ccl/type_traits.hpp" +#include "oneapi/ccl/types_policy.hpp" + +#include "oneapi/ccl/kvs_attr_ids.hpp" +#include "oneapi/ccl/kvs_attr_ids_traits.hpp" +#include "oneapi/ccl/kvs_attr.hpp" + #include "kvs_impl.hpp" diff --git a/src/ccl_cpp_stream.cpp b/src/ccl_cpp_stream.cpp index 1d36237a5..2fc96e2be 100644 --- a/src/ccl_cpp_stream.cpp +++ b/src/ccl_cpp_stream.cpp @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "oneapi/ccl/ccl_types.hpp" +#include "oneapi/ccl/types.hpp" #include "stream_impl.hpp" #include "oneapi/ccl/native_device_api/export_api.hpp" -#ifndef COMMA -#define COMMA , -#endif namespace ccl { + +namespace v1 { + CCL_API stream::stream( - const typename details::ccl_api_type_attr_traits::type& + const typename detail::ccl_api_type_attr_traits::type& version) : base_t(impl_value_t()) {} @@ -45,30 +45,25 @@ CCL_API void stream::build_from_params() { get_impl()->build_from_params(); } -CCL_API stream::native_t& stream::get_native() -{ +CCL_API stream::native_t& stream::get_native() { return const_cast(static_cast(this)->get_native()); } -CCL_API const stream::native_t& stream::get_native() const -{ +CCL_API const stream::native_t& stream::get_native() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } + +} // namespace v1 + } // namespace ccl +API_STREAM_CREATION_FORCE_INSTANTIATION(typename ccl::unified_stream_type::ccl_native_t) +API_STREAM_CREATION_EXT_FORCE_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, + typename ccl::unified_context_type::ccl_native_t) #ifdef CCL_ENABLE_SYCL -API_STREAM_CREATION_FORCE_INSTANTIATION(cl::sycl::queue) API_STREAM_CREATION_FORCE_INSTANTIATION(cl_command_queue) 
-API_STREAM_CREATION_EXT_FORCE_INSTANTIATION(cl::sycl::device, cl::sycl::context) #else -#ifdef MULTI_GPU_SUPPORT -API_STREAM_CREATION_FORCE_INSTANTIATION( - native::cl_base) -API_STREAM_CREATION_FORCE_INSTANTIATION( - ccl::shared_ptr_class>) -API_STREAM_CREATION_FORCE_INSTANTIATION(ccl::shared_ptr_class) -#endif //API_STREAM_CREATION_FORCE_INSTANTIATION(ccl::empty_t) #endif @@ -78,11 +73,9 @@ API_STREAM_FORCE_INSTANTIATION_GET( API_STREAM_FORCE_INSTANTIATION_GET( ccl::stream_attr_id::device); //, typename ccl::unified_device_type::ccl_native_t); API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::context, - typename ccl::unified_device_context_type::ccl_native_t); + typename ccl::unified_context_type::ccl_native_t); API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::ordinal, uint32_t); API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::index, uint32_t); API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::flags, size_t); API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::mode, size_t); API_STREAM_FORCE_INSTANTIATION(ccl::stream_attr_id::priority, size_t); - -#undef COMMA diff --git a/src/ccl_cpp_utils.cpp b/src/ccl_cpp_utils.cpp index 86fc1a450..a9cd45349 100644 --- a/src/ccl_cpp_utils.cpp +++ b/src/ccl_cpp_utils.cpp @@ -15,33 +15,28 @@ */ #include -#include "oneapi/ccl/ccl_config.h" -#include "oneapi/ccl/ccl_types.hpp" +#include "oneapi/ccl/config.h" +#include "oneapi/ccl/lp_types.hpp" +#include "oneapi/ccl/types.hpp" #include "common/utils/enums.hpp" std::ostream& operator<<(std::ostream& out, const ccl::device_index_type& index); namespace ccl { -using datatype_str_enum = - utils::enum_to_str; -CCL_API string_class to_string(const ccl::datatype& dt) { - return datatype_str_enum({ "INT8", - "UINT8", - "INT16", - "UINT16", - "INT32", - "UINT32", - "INT64", - "UINT64", - "FLOAT16", - "FLOAT32", - "FLOAT64", - "BFLOAT16" }) - .choose(dt, "CUSTOM_TYPE"); +std::string to_string(const bfloat16& v) { + std::stringstream ss; + ss << "bf16::data " << v.data; + return 
ss.str(); } -CCL_API +// std::string to_string(const float16& v) { +// std::stringstream ss; +// ss << "fp16::data " << v.data; +// return ss.str(); +// } + +/* CCL_API */ std::string to_string(const device_index_type& device_id) { std::stringstream ss; ss << "[" << std::get(device_id) << ":" diff --git a/src/ccl_empty_attr.cpp b/src/ccl_empty_attr.cpp index 1d8f6ad86..a535a7439 100644 --- a/src/ccl_empty_attr.cpp +++ b/src/ccl_empty_attr.cpp @@ -13,16 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "oneapi/ccl/ccl_types.hpp" +#include "oneapi/ccl/types.hpp" +#include "common/utils/version.hpp" namespace ccl { -ccl::library_version ccl_empty_attr::version{ - CCL_MAJOR_VERSION, CCL_MINOR_VERSION, CCL_UPDATE_VERSION, - CCL_PRODUCT_STATUS, CCL_PRODUCT_BUILD_DATE, CCL_PRODUCT_FULL, -}; + +namespace v1 { + +library_version ccl_empty_attr::version = utils::get_library_version(); template attr ccl_empty_attr::create_empty() { return attr{ ccl_empty_attr::version }; } + +} // namespace v1 + } // namespace ccl diff --git a/src/ccl_empty_coll_attr.cpp b/src/ccl_empty_coll_attr.cpp index 07ffdde88..c0ad30ce2 100644 --- a/src/ccl_empty_coll_attr.cpp +++ b/src/ccl_empty_coll_attr.cpp @@ -13,17 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_aliases.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_type_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/type_traits.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_coll_attr.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" +#include "oneapi/ccl/coll_attr.hpp" namespace ccl { +namespace v1 { + template CCL_API attr ccl_empty_attr::create_empty() { return attr{ ccl_empty_attr::version }; @@ -41,4 +43,6 @@ CCL_API reduce_scatter_attr default_reduce_scatter_attr = CCL_API sparse_allreduce_attr default_sparse_allreduce_attr = ccl_empty_attr::create_empty(); +} // namespace v1 + } // namespace ccl diff --git a/src/ccl_empty_comm_attr.cpp b/src/ccl_empty_comm_attr.cpp new file mode 100644 index 000000000..93418e9f4 --- /dev/null +++ b/src/ccl_empty_comm_attr.cpp @@ -0,0 +1,38 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/type_traits.hpp" + +#include "oneapi/ccl/comm_attr_ids.hpp" +#include "oneapi/ccl/comm_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_attr.hpp" + +namespace ccl { + +namespace v1 { + +template +CCL_API attr ccl_empty_attr::create_empty() { + return attr{ ccl_empty_attr::version }; +} + +CCL_API comm_attr default_comm_attr = ccl_empty_attr::create_empty(); + +} // namespace v1 + +} // namespace ccl diff --git a/src/ccl_empty_comm_split_attr.cpp b/src/ccl_empty_comm_split_attr.cpp new file mode 100644 index 000000000..a113c7ea8 --- /dev/null +++ b/src/ccl_empty_comm_split_attr.cpp @@ -0,0 +1,38 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/type_traits.hpp" + +#include "oneapi/ccl/comm_split_attr_ids.hpp" +#include "oneapi/ccl/comm_split_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_split_attr.hpp" + +namespace ccl { + +namespace v1 { + +template +CCL_API attr ccl_empty_attr::create_empty() { + return attr{ ccl_empty_attr::version }; +} + +CCL_API comm_split_attr default_comm_split_attr = ccl_empty_attr::create_empty(); + +} // namespace v1 + +} // namespace ccl diff --git a/src/ccl_empty_init_attr.cpp b/src/ccl_empty_init_attr.cpp new file mode 100644 index 000000000..af5363dc5 --- /dev/null +++ b/src/ccl_empty_init_attr.cpp @@ -0,0 +1,38 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/type_traits.hpp" + +#include "oneapi/ccl/init_attr_ids.hpp" +#include "oneapi/ccl/init_attr_ids_traits.hpp" +#include "oneapi/ccl/init_attr.hpp" + +namespace ccl { + +namespace v1 { + +template +CCL_API attr ccl_empty_attr::create_empty() { + return attr{ ccl_empty_attr::version }; +} + +CCL_API init_attr default_init_attr = ccl_empty_attr::create_empty(); + +} // namespace v1 + +} // namespace ccl diff --git a/include/oneapi/ccl/ccl_comm_split_attr_ids.hpp b/src/ccl_empty_kvs_attr.cpp similarity index 56% rename from include/oneapi/ccl/ccl_comm_split_attr_ids.hpp rename to src/ccl_empty_kvs_attr.cpp index 3dd55e857..cb1fed3ea 100644 --- a/include/oneapi/ccl/ccl_comm_split_attr_ids.hpp +++ b/src/ccl_empty_kvs_attr.cpp @@ -13,34 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - -#ifndef CCL_PRODUCT_FULL -#error "Do not include this file directly. 
Please include 'ccl.hpp'" -#endif - -namespace ccl { - -enum class comm_split_attr_id : int { - version, - - color, - group, - - last_value -}; - -enum class - group_split_type : int { // TODO fill in this enum with the actual values - undetermined = -1, - //device, - thread, - process, - //socket, - //node, - cluster, - - last_value - }; - -} // namespace ccl +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/type_traits.hpp" + +#include "oneapi/ccl/kvs_attr_ids.hpp" +#include "oneapi/ccl/kvs_attr_ids_traits.hpp" +#include "oneapi/ccl/kvs_attr.hpp" + +namespace ccl { + +namespace v1 { + +template +CCL_API attr ccl_empty_attr::create_empty() { + return attr{ ccl_empty_attr::version }; +} + +CCL_API kvs_attr default_kvs_attr = ccl_empty_attr::create_empty(); + +} // namespace v1 + +} // namespace ccl diff --git a/src/ccl_empty_stream.cpp b/src/ccl_empty_stream.cpp index 041a8e779..eb1d993b7 100644 --- a/src/ccl_empty_stream.cpp +++ b/src/ccl_empty_stream.cpp @@ -13,20 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_aliases.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_type_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/type_traits.hpp" -#include "oneapi/ccl/ccl_stream_attr_ids.hpp" -#include "oneapi/ccl/ccl_stream_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_stream.hpp" +#include "oneapi/ccl/stream_attr_ids.hpp" +#include "oneapi/ccl/stream_attr_ids_traits.hpp" +#include "oneapi/ccl/stream.hpp" // Core file with PIMPL implementation //#include "stream_impl.hpp" namespace ccl { +namespace v1 { + template CCL_API attr ccl_empty_attr::create_empty() { return attr{ ccl_empty_attr::version }; @@ -34,4 +36,6 @@ CCL_API attr ccl_empty_attr::create_empty() { CCL_API stream default_stream = ccl_empty_attr::create_empty(); +} // namespace v1 + } // namespace ccl diff --git a/src/ccl_gpu_module.hpp b/src/ccl_gpu_module.hpp index 5e7ce6a12..6a49ecfe1 100644 --- a/src/ccl_gpu_module.hpp +++ b/src/ccl_gpu_module.hpp @@ -15,11 +15,12 @@ */ #pragma once -#include "oneapi/ccl/ccl_types.hpp" +#include "oneapi/ccl/types.hpp" #include "coll/algorithms/algorithms_enum.hpp" +#include "internal_types.hpp" #ifdef MULTI_GPU_SUPPORT -ccl_status_t CCL_API register_gpu_module_source(const char* source, - ccl::device_topology_type topology_class, - ccl_coll_type type); +ccl::status register_gpu_module_source(const char* source, + ccl::device_topology_type topology_class, + ccl_coll_type type); #endif //MULTI_GPU_SUPPORT diff --git a/src/ccl_gpu_modules.cpp b/src/ccl_gpu_modules.cpp index 07c977a7b..88e2376b8 100644 --- a/src/ccl_gpu_modules.cpp +++ b/src/ccl_gpu_modules.cpp @@ -23,9 +23,9 @@ #include "common/comm/l0/device_group_routing_schema.hpp" #include "coll/algorithms/algorithms_enum.hpp" -ccl_status_t CCL_API register_gpu_module_source(const char* path, - ccl::device_topology_type topology_class, - 
ccl_coll_type type) { +ccl::status register_gpu_module_source(const char* path, + ccl::device_topology_type topology_class, + ccl_coll_type type) { ccl::device_topology_type t_class = static_cast(topology_class); char pwd[PATH_MAX]; char* ret = getcwd(pwd, sizeof(pwd)); @@ -70,17 +70,17 @@ ccl_status_t CCL_API register_gpu_module_source(const char* path, native::specific_modules_source_data_storage::instance() .load_kernel_source(path, t_class); break; - default: - throw std::runtime_error(std::string(__PRETTY_FUNCTION__) + - " - get unexpected ccl collective type: " + - std::to_string(type)); + default: + throw std::runtime_error( + std::string(__PRETTY_FUNCTION__) + + " - get unexpected ccl collective type: " + std::to_string(type)); break; } } catch (const std::exception& ex) { LOG_ERROR("Cannot preload kernel source by path: ", path, ", error: ", ex.what()); CCL_ASSERT(false); - return ccl_status_runtime_error; + return ccl::status::runtime_error; } LOG_INFO("gpu kernel source by type \"", @@ -88,7 +88,7 @@ ccl_status_t CCL_API register_gpu_module_source(const char* path, "\", topology class: \"", to_string(t_class), "\" loaded succesfully"); - return ccl_status_success; + return ccl::status::success; } #endif //MULTI_GPU_SUPPORT diff --git a/src/ccl_utils.cpp b/src/ccl_utils.cpp index 2f723e6b8..411dc9bea 100644 --- a/src/ccl_utils.cpp +++ b/src/ccl_utils.cpp @@ -16,8 +16,7 @@ #include #include -#include "ccl.hpp" -#include "ccl_type_traits.hpp" +#include "oneapi/ccl/types.hpp" std::ostream& operator<<(std::ostream& out, const ccl::device_index_type& index) { out << ccl::to_string(index); @@ -26,7 +25,6 @@ std::ostream& operator<<(std::ostream& out, const ccl::device_index_type& index) namespace ccl { -CCL_API std::string to_string(const device_index_type& device_id) { std::stringstream ss; ss << "[" << std::get(device_id) << ":" @@ -44,7 +42,6 @@ std::string to_string(const device_index_type& device_id) { return ss.str(); } -CCL_API device_index_type 
from_string(const std::string& device_id_str) { std::string::size_type from_pos = device_id_str.find('['); if (from_pos == std::string::npos) { diff --git a/src/coll/algorithms/algorithms.hpp b/src/coll/algorithms/algorithms.hpp index 5270a2103..57b5e8456 100644 --- a/src/coll/algorithms/algorithms.hpp +++ b/src/coll/algorithms/algorithms.hpp @@ -17,167 +17,88 @@ #include "sched/master_sched.hpp" #include "sched/sched.hpp" +#include "internal_types.hpp" #include #include #define CCL_UNDEFINED_ALGO_ID (-1) -ccl_status_t ccl_coll_build_naive_bcast(ccl_sched* sched, - ccl_buffer buf, - size_t count, - const ccl_datatype& dtype, - size_t root, - ccl_comm* comm); - -ccl_status_t ccl_coll_build_scatter_ring_allgather_bcast(ccl_sched* sched, - ccl_buffer buf, - size_t count, - const ccl_datatype& dtype, - size_t root, - ccl_comm* comm); - -ccl_status_t ccl_coll_build_dissemination_barrier(ccl_sched* sched, ccl_comm* comm); - -ccl_status_t ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - size_t root, - ccl_comm* comm); - -ccl_status_t ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - ccl_comm* comm); - -ccl_status_t ccl_coll_build_binomial_reduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - size_t root, - ccl_comm* comm); - -ccl_status_t ccl_coll_build_ring_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - ccl_comm* comm); +ccl::status ccl_coll_build_naive_bcast(ccl_sched* sched, + ccl_buffer buf, + size_t count, + const ccl_datatype& dtype, + int root, + ccl_comm* comm); -ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, 
- ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - ccl_comm* comm); +ccl::status ccl_coll_build_scatter_ring_allgather_bcast(ccl_sched* sched, + ccl_buffer buf, + size_t count, + const ccl_datatype& dtype, + int root, + ccl_comm* comm); -ccl_status_t ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - ccl_comm* comm); +ccl::status ccl_coll_build_dissemination_barrier(ccl_sched* sched, ccl_comm* comm); -ccl_status_t ccl_coll_build_starlike_allreduce(ccl_sched* sched, +ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, size_t count, const ccl_datatype& dtype, ccl::reduction reduction, + int root, ccl_comm* comm); -ccl_status_t ccl_coll_build_naive_allgatherv(ccl_sched* sched, - ccl_buffer send_buf, - size_t send_count, - ccl_buffer recv_buf, - const size_t* recv_counts, - const ccl_datatype& dtype, - ccl_comm* comm); - -template -ccl_status_t ccl_coll_build_sparse_allreduce_ring(ccl_sched* sched, - ccl_buffer send_ind_buf, - size_t send_ind_count, - ccl_buffer send_val_buf, - size_t send_val_count, - void** recv_ind_buf, - size_t* recv_ind_count, - void** recv_val_buf, - size_t* recv_val_count, - const ccl_datatype& index_dtype, - const ccl_datatype& value_dtype, - ccl::reduction reduction, - ccl_comm* comm); - -template -ccl_status_t ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, - ccl_buffer send_ind_buf, - size_t send_ind_count, - ccl_buffer send_val_buf, - size_t send_val_count, - void** recv_ind_buf, - size_t* recv_ind_count, - void** recv_val_buf, - size_t* recv_val_count, - const ccl_datatype& index_dtype, - const ccl_datatype& value_dtype, +ccl::status ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, 
ccl::reduction reduction, ccl_comm* comm); -template -ccl_status_t ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, - ccl_buffer send_ind_buf, - size_t send_ind_count, - ccl_buffer send_val_buf, - size_t send_val_count, - void** recv_ind_buf, - size_t* recv_ind_count, - void** recv_val_buf, - size_t* recv_val_count, - const ccl_datatype& index_dtype, - const ccl_datatype& value_dtype, - ccl::reduction reduction, - ccl_comm* comm); - -class ccl_double_tree; -ccl_status_t ccl_coll_build_double_tree_op(ccl_sched* sched, - ccl_coll_type coll_type, +ccl::status ccl_coll_build_binomial_reduce(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, size_t count, const ccl_datatype& dtype, ccl::reduction reduction, - const ccl_double_tree& dtree, + int root, ccl_comm* comm); -ccl_status_t ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t send_count, - const ccl_datatype& dtype, - ccl::reduction reduction, - ccl_comm* comm); - -ccl_status_t ccl_coll_build_ring_reduce_scatter_block(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t recv_count, - const ccl_datatype& dtype, - ccl::reduction reduction, - ccl_comm* comm); - -ccl_status_t ccl_coll_build_ring_allgatherv(ccl_sched* sched, +ccl::status ccl_coll_build_ring_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm); + +ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm); + +ccl::status ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm); + +ccl::status ccl_coll_build_starlike_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + 
ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm); + +ccl::status ccl_coll_build_naive_allgatherv(ccl_sched* sched, ccl_buffer send_buf, size_t send_count, ccl_buffer recv_buf, @@ -185,70 +106,150 @@ ccl_status_t ccl_coll_build_ring_allgatherv(ccl_sched* sched, const ccl_datatype& dtype, ccl_comm* comm); -ccl_status_t ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, - std::vector& scheds, - const ccl_coll_param& coll_param); - -ccl_status_t ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, - std::vector& scheds, - const ccl_coll_param& coll_param); - -ccl_status_t ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sched, - std::vector& scheds, - const ccl_coll_param& coll_param); +template +ccl::status ccl_coll_build_sparse_allreduce_ring(ccl_sched* sched, + ccl_buffer send_ind_buf, + size_t send_ind_count, + ccl_buffer send_val_buf, + size_t send_val_count, + void** recv_ind_buf, + size_t* recv_ind_count, + void** recv_val_buf, + size_t* recv_val_count, + const ccl_datatype& index_dtype, + const ccl_datatype& value_dtype, + ccl::reduction reduction, + ccl_comm* comm); -/* direct algorithms - i.e. 
direct mapping on collective API from transport level */ +template +ccl::status ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, + ccl_buffer send_ind_buf, + size_t send_ind_count, + ccl_buffer send_val_buf, + size_t send_val_count, + void** recv_ind_buf, + size_t* recv_ind_count, + void** recv_val_buf, + size_t* recv_val_count, + const ccl_datatype& index_dtype, + const ccl_datatype& value_dtype, + ccl::reduction reduction, + ccl_comm* comm); -ccl_status_t ccl_coll_build_direct_barrier(ccl_sched* sched, ccl_comm* comm); +template +ccl::status ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, + ccl_buffer send_ind_buf, + size_t send_ind_count, + ccl_buffer send_val_buf, + size_t send_val_count, + void** recv_ind_buf, + size_t* recv_ind_count, + void** recv_val_buf, + size_t* recv_val_count, + const ccl_datatype& index_dtype, + const ccl_datatype& value_dtype, + ccl::reduction reduction, + ccl_comm* comm); -ccl_status_t ccl_coll_build_direct_reduce(ccl_sched* sched, +class ccl_double_tree; +ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, + ccl_coll_type coll_type, ccl_buffer send_buf, ccl_buffer recv_buf, size_t count, const ccl_datatype& dtype, ccl::reduction reduction, - size_t root, + const ccl_double_tree& dtree, ccl_comm* comm); -ccl_status_t ccl_coll_build_direct_allgatherv(ccl_sched* sched, - ccl_buffer send_buf, - size_t send_count, - ccl_buffer recv_buf, - const size_t* recv_counts, - const ccl_datatype& dtype, - ccl_comm* comm); +ccl::status ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t send_count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm); + +ccl::status ccl_coll_build_ring_reduce_scatter_block(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t recv_count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm); + +ccl::status ccl_coll_build_ring_allgatherv(ccl_sched* sched, + ccl_buffer 
send_buf, + size_t send_count, + ccl_buffer recv_buf, + const size_t* recv_counts, + const ccl_datatype& dtype, + ccl_comm* comm); + +ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, + std::vector& scheds, + const ccl_coll_param& coll_param); + +ccl::status ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, + std::vector& scheds, + const ccl_coll_param& coll_param); + +ccl::status ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sched, + std::vector& scheds, + const ccl_coll_param& coll_param); + +/* direct algorithms - i.e. direct mapping on collective API from transport level */ -ccl_status_t ccl_coll_build_direct_allreduce(ccl_sched* sched, +ccl::status ccl_coll_build_direct_barrier(ccl_sched* sched, ccl_comm* comm); + +ccl::status ccl_coll_build_direct_reduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + int root, + ccl_comm* comm); + +ccl::status ccl_coll_build_direct_allgatherv(ccl_sched* sched, ccl_buffer send_buf, + size_t send_count, ccl_buffer recv_buf, - size_t count, + const size_t* recv_counts, const ccl_datatype& dtype, - ccl::reduction reduction, ccl_comm* comm); -ccl_status_t ccl_coll_build_direct_alltoall(ccl_sched* sched, +ccl::status ccl_coll_build_direct_allreduce(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, size_t count, const ccl_datatype& dtype, + ccl::reduction reduction, ccl_comm* comm); -ccl_status_t ccl_coll_build_direct_alltoallv(ccl_sched* sched, - ccl_buffer send_buf, - const size_t* send_counts, - ccl_buffer recv_buf, - const size_t* recv_counts, - const ccl_datatype& dtype, - ccl_comm* comm); +ccl::status ccl_coll_build_direct_alltoall(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl_comm* comm); -ccl_status_t ccl_coll_build_direct_bcast(ccl_sched* sched, - ccl_buffer buf, - size_t count, - const ccl_datatype& 
dtype, - size_t root, - ccl_comm* comm); +ccl::status ccl_coll_build_direct_alltoallv(ccl_sched* sched, + ccl_buffer send_buf, + const size_t* send_counts, + ccl_buffer recv_buf, + const size_t* recv_counts, + const ccl_datatype& dtype, + ccl_comm* comm); + +ccl::status ccl_coll_build_direct_bcast(ccl_sched* sched, + ccl_buffer buf, + size_t count, + const ccl_datatype& dtype, + int root, + ccl_comm* comm); -ccl_status_t ccl_coll_build_direct_reduce_scatter(ccl_sched* sched, +ccl::status ccl_coll_build_direct_reduce_scatter(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, size_t send_count, diff --git a/src/coll/algorithms/allgatherv.cpp b/src/coll/algorithms/allgatherv.cpp index 17f25f3e3..329c50752 100644 --- a/src/coll/algorithms/allgatherv.cpp +++ b/src/coll/algorithms/allgatherv.cpp @@ -17,37 +17,37 @@ #include "sched/entry/factory/chunked_entry_factory.hpp" #include "sched/entry/factory/entry_factory.hpp" -ccl_status_t ccl_coll_build_direct_allgatherv(ccl_sched* sched, - ccl_buffer send_buf, - size_t send_count, - ccl_buffer recv_buf, - const size_t* recv_counts, - const ccl_datatype& dtype, - ccl_comm* comm) { - LOG_DEBUG("build direct allgatherv"); - - entry_factory::make_entry( - sched, send_buf, send_count, recv_buf, recv_counts, dtype, comm); - return ccl_status_success; -} - -ccl_status_t ccl_coll_build_naive_allgatherv(ccl_sched* sched, +ccl::status ccl_coll_build_direct_allgatherv(ccl_sched* sched, ccl_buffer send_buf, size_t send_count, ccl_buffer recv_buf, const size_t* recv_counts, const ccl_datatype& dtype, ccl_comm* comm) { + LOG_DEBUG("build direct allgatherv"); + + entry_factory::make_entry( + sched, send_buf, send_count, recv_buf, recv_counts, dtype, comm); + return ccl::status::success; +} + +ccl::status ccl_coll_build_naive_allgatherv(ccl_sched* sched, + ccl_buffer send_buf, + size_t send_count, + ccl_buffer recv_buf, + const size_t* recv_counts, + const ccl_datatype& dtype, + ccl_comm* comm) { LOG_DEBUG("build naive allgatherv"); 
- size_t comm_size = comm->size(); - size_t this_rank = comm->rank(); + int comm_size = comm->size(); + int this_rank = comm->rank(); size_t dtype_size = dtype.size(); size_t* offsets = static_cast(CCL_MALLOC(comm_size * sizeof(size_t), "offsets")); - ccl_status_t status = ccl_status_success; + ccl::status status = ccl::status::success; offsets[0] = 0; - for (size_t rank_idx = 1; rank_idx < comm_size; ++rank_idx) { + for (int rank_idx = 1; rank_idx < comm_size; ++rank_idx) { offsets[rank_idx] = offsets[rank_idx - 1] + recv_counts[rank_idx - 1] * dtype_size; } @@ -57,7 +57,7 @@ ccl_status_t ccl_coll_build_naive_allgatherv(ccl_sched* sched, sched, send_buf, recv_buf + offsets[this_rank], send_count, dtype); } - for (size_t rank_idx = 0; rank_idx < comm_size; ++rank_idx) { + for (int rank_idx = 0; rank_idx < comm_size; ++rank_idx) { if (rank_idx != this_rank) { // send own buffer to other ranks entry_factory::make_chunked_send_entry( @@ -72,27 +72,26 @@ ccl_status_t ccl_coll_build_naive_allgatherv(ccl_sched* sched, return status; } -ccl_status_t ccl_coll_build_ring_allgatherv(ccl_sched* sched, - ccl_buffer send_buf, - size_t send_count, - ccl_buffer recv_buf, - const size_t* recv_counts, - const ccl_datatype& dtype, - ccl_comm* comm) { +ccl::status ccl_coll_build_ring_allgatherv(ccl_sched* sched, + ccl_buffer send_buf, + size_t send_count, + ccl_buffer recv_buf, + const size_t* recv_counts, + const ccl_datatype& dtype, + ccl_comm* comm) { LOG_DEBUG("build ring allgatherv, send_count ", send_count); - ccl_status_t status = ccl_status_success; - size_t comm_size, rank; + ccl::status status = ccl::status::success; + int comm_size, rank; size_t dtype_size = dtype.size(); - size_t idx = 0; - size_t src, dst; + int src, dst; comm_size = comm->size(); rank = comm->rank(); size_t* offsets = static_cast(CCL_MALLOC(comm_size * sizeof(size_t), "offsets")); offsets[0] = 0; - for (size_t rank_idx = 1; rank_idx < comm_size; ++rank_idx) { + for (int rank_idx = 1; rank_idx < 
comm_size; ++rank_idx) { offsets[rank_idx] = offsets[rank_idx - 1] + recv_counts[rank_idx - 1] * dtype_size; } @@ -112,7 +111,7 @@ ccl_status_t ccl_coll_build_ring_allgatherv(ccl_sched* sched, size_t send_block_count, recv_block_count; size_t send_block_offset, recv_block_offset; - for (idx = 0; idx < (comm_size - 1); idx++) { + for (int idx = 0; idx < (comm_size - 1); idx++) { send_block_idx = block_idx; recv_block_idx = (comm_size + block_idx - 1) % comm_size; send_block_count = recv_counts[send_block_idx]; diff --git a/src/coll/algorithms/allreduce/allreduce.cpp b/src/coll/algorithms/allreduce/allreduce.cpp index e620e2390..ab512a617 100644 --- a/src/coll/algorithms/allreduce/allreduce.cpp +++ b/src/coll/algorithms/allreduce/allreduce.cpp @@ -24,30 +24,30 @@ #include "sched/entry/factory/chunked_entry_factory.hpp" #include "sched/entry/factory/entry_factory.hpp" -ccl_status_t ccl_coll_build_direct_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm) { +ccl::status ccl_coll_build_direct_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm) { LOG_DEBUG("build direct allreduce"); entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype, op, comm); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm) { +ccl::status ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm) { LOG_DEBUG("build Rabenseifner's allreduce"); CCL_ASSERT(sched != nullptr, "empty sched"); - ccl_status_t status = ccl_status_success; + ccl::status 
status = ccl::status::success; int comm_size, rank, newrank, pof2, rem; int i, send_idx, recv_idx, last_idx, mask, newdst, dst, send_cnt, recv_cnt; int *cnts = NULL, *disps = NULL; @@ -269,16 +269,16 @@ ccl_status_t ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, return status; } -ccl_status_t ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm) { +ccl::status ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm) { LOG_DEBUG("build recursive_doubling allreduce"); - ccl_status_t status = ccl_status_success; + ccl::status status = ccl::status::success; int pof2, rem, comm_size, rank; int newrank, mask, newdst, dst; @@ -378,18 +378,18 @@ ccl_status_t ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, return status; } -ccl_status_t ccl_coll_build_starlike_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm) { +ccl::status ccl_coll_build_starlike_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm) { LOG_DEBUG("build starlike allreduce"); - ccl_status_t status = ccl_status_success; - size_t comm_size = comm->size(); - size_t this_rank = comm->rank(); + ccl::status status = ccl::status::success; + int comm_size = comm->size(); + int this_rank = comm->rank(); size_t* buffer_counts = static_cast(CCL_MALLOC(comm_size * sizeof(size_t), "buffer_count")); size_t* buffer_offsets = @@ -407,7 +407,7 @@ ccl_status_t ccl_coll_build_starlike_allreduce(ccl_sched* sched, // calculate counts and offsets for each rank size_t common_buffer_count = count / comm_size; - for (size_t rank_idx = 
0; rank_idx < comm_size; ++rank_idx) { + for (int rank_idx = 0; rank_idx < comm_size; ++rank_idx) { buffer_counts[rank_idx] = common_buffer_count; buffer_offsets[rank_idx] = rank_idx * buffer_counts[rank_idx] * dtype_size; } @@ -421,7 +421,7 @@ ccl_status_t ccl_coll_build_starlike_allreduce(ccl_sched* sched, tmp_buf = sched->alloc_buffer(this_rank_buf_size * (comm_size - 1)); size_t tmp_buf_recv_idx = 0; - for (size_t rank_idx = 0; rank_idx < comm_size; ++rank_idx) { + for (int rank_idx = 0; rank_idx < comm_size; ++rank_idx) { if (rank_idx != this_rank) { // send buffer to others entry_factory::make_chunked_send_entry(sched, @@ -458,13 +458,13 @@ ccl_status_t ccl_coll_build_starlike_allreduce(ccl_sched* sched, return status; } -ccl_status_t ccl_coll_build_ring_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm) { +ccl::status ccl_coll_build_ring_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm) { int inplace = (send_buf == recv_buf) ? 1 : 0; LOG_DEBUG("build ring allreduce ", inplace ? 
"in-place" : "out-of-place"); @@ -476,13 +476,13 @@ ccl_status_t ccl_coll_build_ring_allreduce(ccl_sched* sched, " recv ", recv_buf); - ccl_status_t status = ccl_status_success; + ccl::status status = ccl::status::success; ccl_coll_build_ring_reduce_scatter(sched, send_buf, recv_buf, count, dtype, op, comm); sched->add_barrier(); - size_t comm_size = comm->size(); + int comm_size = comm->size(); size_t main_block_count = count / comm_size; size_t last_block_count = main_block_count + count % comm_size; std::vector recv_counts(comm_size, main_block_count); diff --git a/src/coll/algorithms/allreduce/allreduce_2d.cpp b/src/coll/algorithms/allreduce/allreduce_2d.cpp index 76ad38245..487de6976 100644 --- a/src/coll/algorithms/allreduce/allreduce_2d.cpp +++ b/src/coll/algorithms/allreduce/allreduce_2d.cpp @@ -18,8 +18,9 @@ #include "common/global/global.hpp" #include "sched/entry/factory/entry_factory.hpp" -ccl_allreduce_2d_builder::ccl_allreduce_2d_builder(size_t base_size, bool switch_dims, ccl_comm* comm) { - +ccl_allreduce_2d_builder::ccl_allreduce_2d_builder(size_t base_size, + bool switch_dims, + ccl_comm* comm) { parent_comm = comm; size_t vector_size = comm->size(); @@ -36,23 +37,19 @@ ccl_allreduce_2d_builder::ccl_allreduce_2d_builder(size_t base_size, bool switch } } - first_dim_comm = std::shared_ptr( - ccl_comm::create_with_colors(first_dim_colors, - ccl::global_data::get().comm_ids.get(), - comm, true /*share_resources*/)); + first_dim_comm = std::shared_ptr(ccl_comm::create_with_colors( + first_dim_colors, ccl::global_data::get().comm_ids.get(), comm, true /*share_resources*/)); - second_dim_comm = std::shared_ptr( - ccl_comm::create_with_colors(second_dim_colors, - ccl::global_data::get().comm_ids.get(), - comm, true /*share_resources*/)); + second_dim_comm = std::shared_ptr(ccl_comm::create_with_colors( + second_dim_colors, ccl::global_data::get().comm_ids.get(), comm, true /*share_resources*/)); if (comm->rank() == 0) { std::string first_dim_ranks, 
second_dim_ranks; - for (size_t idx = 0; idx < first_dim_comm->size(); idx++) { + for (int idx = 0; idx < first_dim_comm->size(); idx++) { first_dim_ranks += ((idx) ? " " : "") + std::to_string(first_dim_comm->get_global_rank(idx)); } - for (size_t idx = 0; idx < second_dim_comm->size(); idx++) { + for (int idx = 0; idx < second_dim_comm->size(); idx++) { second_dim_ranks += ((idx) ? " " : "") + std::to_string(second_dim_comm->get_global_rank(idx)); } @@ -74,121 +71,120 @@ ccl_allreduce_2d_builder::ccl_allreduce_2d_builder(size_t base_size, bool switch } ccl_allreduce_2d_builder::~ccl_allreduce_2d_builder() { - first_dim_comm.reset(); - second_dim_comm.reset(); + first_dim_comm.reset(); + second_dim_comm.reset(); } static void ccl_allreduce_2d_add_allreduce_allgather(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm, - size_t chunk_idx, - size_t chunk_count) { - - ccl_comm* first_dim_comm = comm->allreduce_2d_builder->get_first_dim_comm(); - ccl_comm* second_dim_comm = comm->allreduce_2d_builder->get_second_dim_comm(); - - size_t dtype_size = dtype.size(); - size_t main_chunk_size = count / chunk_count; - size_t last_chunk_size = main_chunk_size + count % chunk_count; - size_t cnt = (chunk_idx == (chunk_count - 1)) ? last_chunk_size : main_chunk_size; - ccl_buffer rbuf = recv_buf + chunk_idx * main_chunk_size * dtype_size; - - size_t main_block_count = cnt / first_dim_comm->size(); - size_t last_block_count = main_block_count + cnt % first_dim_comm->size(); - size_t ar_count = (first_dim_comm->rank() == (first_dim_comm->size() - 1)) ? 
last_block_count - : main_block_count; - - if (ar_count) { - /* TODO: add second level selection to distinguish high and low level algorithms */ - ccl_buffer ar_buf = rbuf + first_dim_comm->rank() * main_block_count * dtype_size; - ccl_coll_build_starlike_allreduce( - sched, ar_buf, ar_buf, ar_count, dtype, op, second_dim_comm); - sched->add_barrier(); - } - - std::vector ag_recv_counts(first_dim_comm->size(), main_block_count); - ag_recv_counts[first_dim_comm->size() - 1] = last_block_count; - ccl_coll_build_allgatherv( - sched, rbuf, ar_count, rbuf, ag_recv_counts.data(), dtype, first_dim_comm); + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm, + size_t chunk_idx, + size_t chunk_count) { + ccl_comm* first_dim_comm = comm->allreduce_2d_builder->get_first_dim_comm(); + ccl_comm* second_dim_comm = comm->allreduce_2d_builder->get_second_dim_comm(); + + size_t dtype_size = dtype.size(); + size_t main_chunk_size = count / chunk_count; + size_t last_chunk_size = main_chunk_size + count % chunk_count; + size_t cnt = (chunk_idx == (chunk_count - 1)) ? last_chunk_size : main_chunk_size; + ccl_buffer rbuf = recv_buf + chunk_idx * main_chunk_size * dtype_size; + + size_t main_block_count = cnt / first_dim_comm->size(); + size_t last_block_count = main_block_count + cnt % first_dim_comm->size(); + size_t ar_count = (first_dim_comm->rank() == (first_dim_comm->size() - 1)) ? 
last_block_count + : main_block_count; + + if (ar_count) { + /* TODO: add second level selection to distinguish high and low level algorithms */ + ccl_buffer ar_buf = rbuf + first_dim_comm->rank() * main_block_count * dtype_size; + ccl_coll_build_starlike_allreduce( + sched, ar_buf, ar_buf, ar_count, dtype, op, second_dim_comm); + sched->add_barrier(); + } + + std::vector ag_recv_counts(first_dim_comm->size(), main_block_count); + ag_recv_counts[first_dim_comm->size() - 1] = last_block_count; + ccl_coll_build_allgatherv( + sched, rbuf, ar_count, rbuf, ag_recv_counts.data(), dtype, first_dim_comm); } static void ccl_allreduce_2d_add_reduce_scatter_allreduce_allgather(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm, - size_t chunk_idx, - size_t chunk_count) { - ccl_comm* first_dim_comm = comm->allreduce_2d_builder->get_first_dim_comm(); - - size_t dtype_size = dtype.size(); - size_t main_chunk_size = count / chunk_count; - size_t last_chunk_size = main_chunk_size + count % chunk_count; - size_t cnt = (chunk_idx == (chunk_count - 1)) ? 
last_chunk_size : main_chunk_size; - ccl_buffer sbuf = send_buf + chunk_idx * main_chunk_size * dtype_size; - ccl_buffer rbuf = recv_buf + chunk_idx * main_chunk_size * dtype_size; - - ccl_coll_build_reduce_scatter(sched, sbuf, rbuf, cnt, dtype, op, first_dim_comm, true); - sched->add_barrier(); - - if (chunk_idx == (chunk_count - 1) || (chunk_count == 1)) { - ccl_allreduce_2d_add_allreduce_allgather( - sched, send_buf, recv_buf, count, dtype, op, comm, chunk_idx, chunk_count); - } - else { - entry_factory::make_entry( - sched, - chunk_idx, - [send_buf, recv_buf, count, &dtype, op, comm, chunk_idx, chunk_count](ccl_sched* s) { - ccl_allreduce_2d_add_allreduce_allgather( - s, send_buf, recv_buf, count, dtype, op, comm, chunk_idx, chunk_count); - }, - "AR_AG"); - - entry_factory::make_entry( - sched, - chunk_idx + 1, - [send_buf, recv_buf, count, &dtype, op, comm, chunk_idx, chunk_count](ccl_sched* s) { - ccl_allreduce_2d_add_reduce_scatter_allreduce_allgather( - s, send_buf, recv_buf, count, dtype, op, comm, chunk_idx + 1, chunk_count); - }, - "RS_AR_AG"); - } + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm, + size_t chunk_idx, + size_t chunk_count) { + ccl_comm* first_dim_comm = comm->allreduce_2d_builder->get_first_dim_comm(); + + size_t dtype_size = dtype.size(); + size_t main_chunk_size = count / chunk_count; + size_t last_chunk_size = main_chunk_size + count % chunk_count; + size_t cnt = (chunk_idx == (chunk_count - 1)) ? 
last_chunk_size : main_chunk_size; + ccl_buffer sbuf = send_buf + chunk_idx * main_chunk_size * dtype_size; + ccl_buffer rbuf = recv_buf + chunk_idx * main_chunk_size * dtype_size; + + ccl_coll_build_reduce_scatter(sched, sbuf, rbuf, cnt, dtype, op, first_dim_comm, true); + sched->add_barrier(); + + if (chunk_idx == (chunk_count - 1) || (chunk_count == 1)) { + ccl_allreduce_2d_add_allreduce_allgather( + sched, send_buf, recv_buf, count, dtype, op, comm, chunk_idx, chunk_count); + } + else { + entry_factory::make_entry( + sched, + chunk_idx, + [send_buf, recv_buf, count, &dtype, op, comm, chunk_idx, chunk_count](ccl_sched* s) { + ccl_allreduce_2d_add_allreduce_allgather( + s, send_buf, recv_buf, count, dtype, op, comm, chunk_idx, chunk_count); + }, + "AR_AG"); + + entry_factory::make_entry( + sched, + chunk_idx + 1, + [send_buf, recv_buf, count, &dtype, op, comm, chunk_idx, chunk_count](ccl_sched* s) { + ccl_allreduce_2d_add_reduce_scatter_allreduce_allgather( + s, send_buf, recv_buf, count, dtype, op, comm, chunk_idx + 1, chunk_count); + }, + "RS_AR_AG"); + } } -ccl_status_t ccl_allreduce_2d_builder::build(ccl_sched* sched, +ccl::status ccl_allreduce_2d_builder::build(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, size_t count, const ccl_datatype& dtype, ccl::reduction op) { - CCL_THROW_IF_NOT(sched && send_buf && recv_buf && count, - "incorrect values, sched ", - sched, - ", send ", - send_buf, - " recv ", - recv_buf); + CCL_THROW_IF_NOT(sched && send_buf && recv_buf && count, + "incorrect values, sched ", + sched, + ", send ", + send_buf, + " recv ", + recv_buf); - ccl_status_t status = ccl_status_success; + ccl::status status = ccl::status::success; - size_t chunk_count = ccl::global_data::env().ar2d_chunk_count; + size_t chunk_count = ccl::global_data::env().ar2d_chunk_count; - if (chunk_count == 0) { - LOG_ERROR("unexpected chunk_count"); - chunk_count = 1; - } + if (chunk_count == 0) { + LOG_ERROR("unexpected chunk_count"); + chunk_count = 1; + } 
- LOG_DEBUG("build 2d allreduce, chunk_count ", chunk_count); + LOG_DEBUG("build 2d allreduce, chunk_count ", chunk_count); - ccl_allreduce_2d_add_reduce_scatter_allreduce_allgather( - sched, send_buf, recv_buf, count, dtype, op, parent_comm, 0 /* chunk_idx */, chunk_count); + ccl_allreduce_2d_add_reduce_scatter_allreduce_allgather( + sched, send_buf, recv_buf, count, dtype, op, parent_comm, 0 /* chunk_idx */, chunk_count); - return status; + return status; } diff --git a/src/coll/algorithms/allreduce/allreduce_2d.hpp b/src/coll/algorithms/allreduce/allreduce_2d.hpp index 01c37531e..5c130c3c4 100644 --- a/src/coll/algorithms/allreduce/allreduce_2d.hpp +++ b/src/coll/algorithms/allreduce/allreduce_2d.hpp @@ -22,31 +22,31 @@ class comm; class ccl_allreduce_2d_builder { public: - ccl_allreduce_2d_builder(size_t base_size, bool switch_dims, ccl_comm* comm); - ~ccl_allreduce_2d_builder(); + ccl_allreduce_2d_builder(size_t base_size, bool switch_dims, ccl_comm* comm); + ~ccl_allreduce_2d_builder(); - ccl_allreduce_2d_builder(const ccl_allreduce_2d_builder&) = delete; - ccl_allreduce_2d_builder(ccl_allreduce_2d_builder&&) = delete; + ccl_allreduce_2d_builder(const ccl_allreduce_2d_builder&) = delete; + ccl_allreduce_2d_builder(ccl_allreduce_2d_builder&&) = delete; - ccl_allreduce_2d_builder& operator=(const ccl_allreduce_2d_builder&) = delete; - ccl_allreduce_2d_builder& operator=(ccl_allreduce_2d_builder&&) = delete; + ccl_allreduce_2d_builder& operator=(const ccl_allreduce_2d_builder&) = delete; + ccl_allreduce_2d_builder& operator=(ccl_allreduce_2d_builder&&) = delete; - ccl_status_t build(ccl_sched* sched, + ccl::status build(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, size_t count, const ccl_datatype& dtype, ccl::reduction op); - ccl_comm* get_first_dim_comm() const { - return first_dim_comm.get(); - } - ccl_comm* get_second_dim_comm() const { - return second_dim_comm.get(); - } + ccl_comm* get_first_dim_comm() const { + return first_dim_comm.get(); + 
} + ccl_comm* get_second_dim_comm() const { + return second_dim_comm.get(); + } private: - ccl_comm* parent_comm; - std::shared_ptr first_dim_comm; - std::shared_ptr second_dim_comm; + ccl_comm* parent_comm; + std::shared_ptr first_dim_comm; + std::shared_ptr second_dim_comm; }; diff --git a/src/coll/algorithms/allreduce/allreduce_rma.cpp b/src/coll/algorithms/allreduce/allreduce_rma.cpp index 506a836b8..a74c95d92 100644 --- a/src/coll/algorithms/allreduce/allreduce_rma.cpp +++ b/src/coll/algorithms/allreduce/allreduce_rma.cpp @@ -18,113 +18,113 @@ #include "sched/entry/factory/entry_factory.hpp" #include "exec/exec.hpp" -ccl_status_t rma_ring_allreduce_reset_sync_flag(const void* ctx) { +ccl::status rma_ring_allreduce_reset_sync_flag(const void* ctx) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; ar_handler->sync_flag = 0; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_reset_dst_ready_flag(const void* ctx) { +ccl::status rma_ring_allreduce_reset_dst_ready_flag(const void* ctx) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; ar_handler->dst_ready_flag = 0; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_get_remote_sync_flag_mr(const void* ctx, void* field_ptr) { +ccl::status rma_ring_allreduce_get_remote_sync_flag_mr(const void* ctx, void* field_ptr) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; atl_mr_t* mr = &(ar_handler->remote_sync_flag_mr); atl_mr_t** mr_ptr = (atl_mr_t**)field_ptr; *mr_ptr = mr; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_get_sync_flag_mr(const void* ctx, void* field_ptr) { +ccl::status rma_ring_allreduce_get_sync_flag_mr(const void* ctx, void* field_ptr) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; atl_mr_t* mr = ar_handler->sync_flag_mr; 
atl_mr_t** mr_ptr = (atl_mr_t**)field_ptr; *mr_ptr = mr; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_get_sync_flags_mr(const void* ctx, void* field_ptr) { +ccl::status rma_ring_allreduce_get_sync_flags_mr(const void* ctx, void* field_ptr) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; atl_mr_t* mr = ar_handler->sync_flags_mr; atl_mr_t** mr_ptr = (atl_mr_t**)field_ptr; *mr_ptr = mr; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_get_send_buf_mr(const void* ctx, void* field_ptr) { +ccl::status rma_ring_allreduce_get_send_buf_mr(const void* ctx, void* field_ptr) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; atl_mr_t* mr = ar_handler->send_buf_mr; atl_mr_t** mr_ptr = (atl_mr_t**)field_ptr; *mr_ptr = mr; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_get_recv_buf_mr(const void* ctx, void* field_ptr) { +ccl::status rma_ring_allreduce_get_recv_buf_mr(const void* ctx, void* field_ptr) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; atl_mr_t* mr = ar_handler->recv_buf_mr; atl_mr_t** mr_ptr = (atl_mr_t**)field_ptr; *mr_ptr = mr; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_get_tmp_buf_mr(const void* ctx, void* field_ptr) { +ccl::status rma_ring_allreduce_get_tmp_buf_mr(const void* ctx, void* field_ptr) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; atl_mr_t* mr = ar_handler->tmp_buf_mr; atl_mr_t** mr_ptr = (atl_mr_t**)field_ptr; *mr_ptr = mr; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_get_dst_ready_flag_mr(const void* ctx, void* field_ptr) { +ccl::status rma_ring_allreduce_get_dst_ready_flag_mr(const void* ctx, void* field_ptr) { ccl_rma_ring_allreduce_handler* ar_handler = 
(ccl_rma_ring_allreduce_handler*)ctx; atl_mr_t* mr = ar_handler->dst_ready_flag_mr; atl_mr_t** mr_ptr = (atl_mr_t**)field_ptr; *mr_ptr = mr; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_get_dst_ready_value_mr(const void* ctx, void* field_ptr) { +ccl::status rma_ring_allreduce_get_dst_ready_value_mr(const void* ctx, void* field_ptr) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; atl_mr_t* mr = ar_handler->dst_ready_value_mr; atl_mr_t** mr_ptr = (atl_mr_t**)field_ptr; *mr_ptr = mr; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_get_remote_dst_ready_flag_mr(const void* ctx, void* field_ptr) { +ccl::status rma_ring_allreduce_get_remote_dst_ready_flag_mr(const void* ctx, void* field_ptr) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; atl_mr_t* mr = &(ar_handler->remote_dst_ready_flag_mr); atl_mr_t** mr_ptr = (atl_mr_t**)field_ptr; *mr_ptr = mr; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_get_remote_rs_dst_buf_mr(const void* ctx, void* field_ptr) { +ccl::status rma_ring_allreduce_get_remote_rs_dst_buf_mr(const void* ctx, void* field_ptr) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; atl_mr_t* mr = &(ar_handler->remote_rs_dst_buf_mr); atl_mr_t** mr_ptr = (atl_mr_t**)field_ptr; *mr_ptr = mr; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t rma_ring_allreduce_get_remote_recv_buf_mr(const void* ctx, void* field_ptr) { +ccl::status rma_ring_allreduce_get_remote_recv_buf_mr(const void* ctx, void* field_ptr) { ccl_rma_ring_allreduce_handler* ar_handler = (ccl_rma_ring_allreduce_handler*)ctx; atl_mr_t* mr = &(ar_handler->remote_recv_buf_mr); atl_mr_t** mr_ptr = (atl_mr_t**)field_ptr; *mr_ptr = mr; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t 
ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm) { +ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm) { int inplace = (send_buf == recv_buf) ? 1 : 0; LOG_DEBUG("build ring rma allreduce (", (inplace) ? "in-place" : "out-of-place", ")"); @@ -136,10 +136,10 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, ", recv ", recv_buf); - ccl_status_t status = ccl_status_success; - size_t comm_size, rank; + ccl::status status = ccl::status::success; + int comm_size, rank; size_t dtype_size = dtype.size(); - size_t idx = 0; + int idx = 0; ccl_buffer tmp_buf; comm_size = comm->size(); rank = comm->rank(); @@ -149,7 +149,7 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype); sched->add_barrier(); } - return ccl_status_success; + return ccl::status::success; } ccl_rma_ring_allreduce_handler* ar_handler = @@ -216,7 +216,7 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, sched, ccl_buffer(&ar_handler->tmp_buf_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), - ccl_datatype_char, + ccl_datatype_int8, ar_handler->src_peer, comm); e->set_field_fn(rma_ring_allreduce_get_tmp_buf_mr, ar_handler); @@ -226,7 +226,7 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, sched, ccl_buffer(&ar_handler->recv_buf_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), - ccl_datatype_char, + ccl_datatype_int8, ar_handler->src_peer, comm); e->set_field_fn(rma_ring_allreduce_get_recv_buf_mr, ar_handler); @@ -235,7 +235,7 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, sched, ccl_buffer(&ar_handler->recv_buf_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), - ccl_datatype_char, + ccl_datatype_int8, 
ar_handler->src_peer, comm); e->set_field_fn(rma_ring_allreduce_get_recv_buf_mr, ar_handler); @@ -244,7 +244,7 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, sched, ccl_buffer(&ar_handler->sync_flag_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), - ccl_datatype_char, + ccl_datatype_int8, ar_handler->src_peer, comm); e->set_field_fn(rma_ring_allreduce_get_sync_flag_mr, ar_handler); @@ -253,21 +253,21 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, sched, ccl_buffer(&ar_handler->remote_rs_dst_buf_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), - ccl_datatype_char, + ccl_datatype_int8, ar_handler->dst_peer, comm); entry_factory::make_entry( sched, ccl_buffer(&ar_handler->remote_recv_buf_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), - ccl_datatype_char, + ccl_datatype_int8, ar_handler->dst_peer, comm); entry_factory::make_entry( sched, ccl_buffer(&ar_handler->remote_sync_flag_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), - ccl_datatype_char, + ccl_datatype_int8, ar_handler->dst_peer, comm); @@ -276,7 +276,7 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, sched, ccl_buffer(ar_handler->dst_ready_flag_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), - ccl_datatype_char, + ccl_datatype_int8, ar_handler->dst_peer, comm); e->set_field_fn(rma_ring_allreduce_get_dst_ready_flag_mr, @@ -285,7 +285,7 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, sched, ccl_buffer(&ar_handler->remote_dst_ready_flag_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), - ccl_datatype_char, + ccl_datatype_int8, ar_handler->src_peer, comm); } @@ -301,7 +301,7 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, ccl_buffer(&ar_handler->dst_ready_value, sizeof(uint64_t)), (atl_mr_t*)nullptr, /* src_mr */ sizeof(uint64_t), - ccl_datatype_char, + ccl_datatype_int8, ar_handler->src_peer, (atl_mr_t*)nullptr, /* dst_mr */ 0 /* dst_buf_offset */, @@ -320,7 +320,7 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, sched, 
rma_ring_allreduce_reset_dst_ready_flag, ar_handler); } - size_t block_idx = rank; + int block_idx = rank; size_t main_block_count = count / comm_size; size_t buf_offset; @@ -362,7 +362,7 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, ccl_buffer(&ar_handler->sync_flags[idx], sizeof(uint64_t)), (atl_mr_t*)nullptr, /* src_mr */ sizeof(uint64_t), - ccl_datatype_char, + ccl_datatype_int8, ar_handler->dst_peer, (atl_mr_t*)nullptr, /* dst_mr */ 0 /* dst_buf_offset */, @@ -423,7 +423,7 @@ ccl_status_t ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, ccl_buffer(&ar_handler->sync_flags[flag_idx_offset + idx], sizeof(uint64_t)), (atl_mr_t*)nullptr, /* src_mr */ sizeof(uint64_t), - ccl_datatype_char, + ccl_datatype_int8, ar_handler->dst_peer, (atl_mr_t*)nullptr, /* dst_mr */ 0 /* dst_buf_offset */, diff --git a/src/coll/algorithms/allreduce/allreduce_rma.hpp b/src/coll/algorithms/allreduce/allreduce_rma.hpp index 2c8013a83..76e2075c8 100644 --- a/src/coll/algorithms/allreduce/allreduce_rma.hpp +++ b/src/coll/algorithms/allreduce/allreduce_rma.hpp @@ -20,8 +20,8 @@ typedef struct { int wait_dst; - size_t src_peer; - size_t dst_peer; + int src_peer; + int dst_peer; volatile uint64_t sync_flag; // src side will write here the index of iteration it completed atl_mr_t* sync_flag_mr; diff --git a/src/coll/algorithms/alltoall.cpp b/src/coll/algorithms/alltoall.cpp index 17db2b02e..2bd43f5cf 100644 --- a/src/coll/algorithms/alltoall.cpp +++ b/src/coll/algorithms/alltoall.cpp @@ -16,14 +16,14 @@ #include "coll/algorithms/algorithms.hpp" #include "sched/entry/factory/entry_factory.hpp" -ccl_status_t ccl_coll_build_direct_alltoall(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl_comm* comm) { +ccl::status ccl_coll_build_direct_alltoall(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl_comm* comm) { LOG_DEBUG("build direct alltoall"); 
entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype, comm); - return ccl_status_success; + return ccl::status::success; } diff --git a/src/coll/algorithms/alltoallv.cpp b/src/coll/algorithms/alltoallv.cpp index 1e180ce3c..f9675483b 100644 --- a/src/coll/algorithms/alltoallv.cpp +++ b/src/coll/algorithms/alltoallv.cpp @@ -26,22 +26,22 @@ #include "sched/entry/factory/chunked_entry_factory.hpp" #include "sched/entry/factory/entry_factory.hpp" -ccl_status_t ccl_coll_build_direct_alltoallv(ccl_sched* sched, - ccl_buffer send_buf, - const size_t* send_counts, - ccl_buffer recv_buf, - const size_t* recv_counts, - const ccl_datatype& dtype, - ccl_comm* comm) { +ccl::status ccl_coll_build_direct_alltoallv(ccl_sched* sched, + ccl_buffer send_buf, + const size_t* send_counts, + ccl_buffer recv_buf, + const size_t* recv_counts, + const ccl_datatype& dtype, + ccl_comm* comm) { LOG_DEBUG("build direct alltoallv"); entry_factory::make_entry( sched, send_buf, send_counts, recv_buf, recv_counts, dtype, comm); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t ccl_coll_add_scatter_alltoallv_barriers(std::vector& scheds, - size_t sched_idx) { +ccl::status ccl_coll_add_scatter_alltoallv_barriers(std::vector& scheds, + size_t sched_idx) { ssize_t max_ops = ccl::global_data::env().alltoall_scatter_max_ops; if (max_ops != CCL_ENV_SIZET_NOT_SPECIFIED) { @@ -56,23 +56,23 @@ ccl_status_t ccl_coll_add_scatter_alltoallv_barriers(std::vector& sc } } - return ccl_status_success; + return ccl::status::success; } -ccl_status_t ccl_coll_calculate_alltoallv_counts(const ccl_coll_param& coll_param, - std::vector& send_counts, - std::vector& recv_counts, - std::vector& send_offsets, - std::vector& recv_offsets, - size_t& total_send_count, - size_t& total_recv_count, - size_t& total_send_bytes, - size_t& total_recv_bytes) { +ccl::status ccl_coll_calculate_alltoallv_counts(const ccl_coll_param& coll_param, + std::vector& send_counts, + std::vector& recv_counts, + 
std::vector& send_offsets, + std::vector& recv_offsets, + size_t& total_send_count, + size_t& total_recv_count, + size_t& total_send_bytes, + size_t& total_recv_bytes) { ccl_coll_type coll_type = coll_param.ctype; ccl_comm* comm = coll_param.comm; const ccl_datatype& dtype = coll_param.dtype; - size_t comm_size = comm->size(); + int comm_size = comm->size(); size_t dtype_size = dtype.size(); if (coll_type == ccl_coll_alltoall) { @@ -91,7 +91,7 @@ ccl_status_t ccl_coll_calculate_alltoallv_counts(const ccl_coll_param& coll_para send_offsets.resize(comm_size, 0); recv_offsets.resize(comm_size, 0); - for (size_t idx = 1; idx < comm_size; idx++) { + for (int idx = 1; idx < comm_size; idx++) { send_offsets[idx] = send_offsets[idx - 1] + send_counts[idx - 1] * dtype_size; recv_offsets[idx] = recv_offsets[idx - 1] + recv_counts[idx - 1] * dtype_size; } @@ -111,19 +111,19 @@ ccl_status_t ccl_coll_calculate_alltoallv_counts(const ccl_coll_param& coll_para ", total_recv_bytes ", total_recv_bytes); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, - std::vector& scheds, - const ccl_coll_param& coll_param) { +ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, + std::vector& scheds, + const ccl_coll_param& coll_param) { LOG_DEBUG("build naive alltoallv"); ccl_comm* comm = coll_param.comm; const ccl_datatype& dtype = coll_param.dtype; - size_t comm_rank = comm->rank(); - size_t comm_size = comm->size(); + int comm_rank = comm->rank(); + int comm_size = comm->size(); size_t sched_count = scheds.size(); size_t dtype_size = dtype.size(); @@ -159,7 +159,7 @@ ccl_status_t ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, dtype); } - for (size_t idx = 0; idx < comm_size; idx++) { + for (int idx = 0; idx < comm_size; idx++) { if (idx == comm_rank) continue; @@ -203,19 +203,19 @@ ccl_status_t ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, } } - return 
ccl_status_success; + return ccl::status::success; } -ccl_status_t ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, - std::vector& scheds, - const ccl_coll_param& coll_param) { +ccl::status ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, + std::vector& scheds, + const ccl_coll_param& coll_param) { LOG_DEBUG("build scatter alltoall"); ccl_comm* comm = coll_param.comm; const ccl_datatype& dtype = coll_param.dtype; - size_t comm_rank = comm->rank(); - size_t comm_size = comm->size(); + int comm_rank = comm->rank(); + int comm_size = comm->size(); size_t sched_count = scheds.size(); size_t dtype_size = dtype.size(); @@ -255,8 +255,8 @@ ccl_status_t ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, dtype); } - for (size_t idx = 0; idx < comm_size; idx++) { - size_t src = (comm_rank + idx) % comm_size; + for (int idx = 0; idx < comm_size; idx++) { + int src = (comm_rank + idx) % comm_size; if (src == comm_rank) continue; @@ -281,8 +281,8 @@ ccl_status_t ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, ccl_coll_add_scatter_alltoallv_barriers(scheds, sched_idx); } - for (size_t idx = 0; idx < comm_size; idx++) { - size_t dst = (comm_rank - idx + comm_size) % comm_size; + for (int idx = 0; idx < comm_size; idx++) { + int dst = (comm_rank - idx + comm_size) % comm_size; if (dst == comm_rank) continue; @@ -305,11 +305,11 @@ ccl_status_t ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, } if (!inplace) - return ccl_status_success; + return ccl::status::success; main_sched->sync_partial_scheds(); - for (size_t idx = 0; idx < comm_size; idx++) { + for (int idx = 0; idx < comm_size; idx++) { if (idx == comm_rank) continue; @@ -325,19 +325,19 @@ ccl_status_t ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, dtype); } - return ccl_status_success; + return ccl::status::success; } -ccl_status_t ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sched, - std::vector& scheds, - const 
ccl_coll_param& coll_param) { +ccl::status ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sched, + std::vector& scheds, + const ccl_coll_param& coll_param) { LOG_DEBUG("build scatter_barrier alltoallv"); ccl_comm* comm = coll_param.comm; const ccl_datatype& dtype = coll_param.dtype; - size_t comm_rank = comm->rank(); - size_t comm_size = comm->size(); + int comm_rank = comm->rank(); + int comm_size = comm->size(); size_t sched_count = scheds.size(); size_t dtype_size = dtype.size(); @@ -394,8 +394,8 @@ ccl_status_t ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sch dtype); } - for (size_t idx = 0; idx < comm_size; idx++) { - size_t src = (comm_rank + idx) % comm_size; + for (int idx = 0; idx < comm_size; idx++) { + int src = (comm_rank + idx) % comm_size; if (src == comm_rank) continue; @@ -423,8 +423,8 @@ ccl_status_t ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sch ccl_coll_add_scatter_alltoallv_barriers(recv_scheds, sched_idx); } - for (size_t idx = 0; idx < comm_size; idx++) { - size_t dst = (comm_rank - idx + comm_size) % comm_size; + for (int idx = 0; idx < comm_size; idx++) { + int dst = (comm_rank - idx + comm_size) % comm_size; if (dst == comm_rank) continue; @@ -447,11 +447,11 @@ ccl_status_t ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sch } if (!inplace) - return ccl_status_success; + return ccl::status::success; main_sched->sync_partial_scheds(); - for (size_t idx = 0; idx < comm_size; idx++) { + for (int idx = 0; idx < comm_size; idx++) { if (idx == comm_rank) continue; @@ -467,5 +467,5 @@ ccl_status_t ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sch dtype); } - return ccl_status_success; + return ccl::status::success; } diff --git a/src/coll/algorithms/barrier.cpp b/src/coll/algorithms/barrier.cpp index f34184512..5aa05e094 100644 --- a/src/coll/algorithms/barrier.cpp +++ b/src/coll/algorithms/barrier.cpp @@ -22,17 +22,17 @@ #include 
"coll/algorithms/algorithms.hpp" #include "sched/entry/factory/entry_factory.hpp" -ccl_status_t ccl_coll_build_direct_barrier(ccl_sched* sched, ccl_comm* comm) { +ccl::status ccl_coll_build_direct_barrier(ccl_sched* sched, ccl_comm* comm) { LOG_DEBUG("build direct barrier"); entry_factory::make_entry(sched, comm); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t ccl_coll_build_dissemination_barrier(ccl_sched* sched, ccl_comm* comm) { +ccl::status ccl_coll_build_dissemination_barrier(ccl_sched* sched, ccl_comm* comm) { LOG_DEBUG("build dissemination barrier"); - ccl_status_t status = ccl_status_success; + ccl::status status = ccl::status::success; int size, rank, src, dst, mask; size = comm->size(); rank = comm->rank(); @@ -44,8 +44,8 @@ ccl_status_t ccl_coll_build_dissemination_barrier(ccl_sched* sched, ccl_comm* co while (mask < size) { dst = (rank + mask) % size; src = (rank - mask + size) % size; - entry_factory::make_entry(sched, ccl_buffer(), 0, ccl_datatype_char, dst, comm); - entry_factory::make_entry(sched, ccl_buffer(), 0, ccl_datatype_char, src, comm); + entry_factory::make_entry(sched, ccl_buffer(), 0, ccl_datatype_int8, dst, comm); + entry_factory::make_entry(sched, ccl_buffer(), 0, ccl_datatype_int8, src, comm); sched->add_barrier(); mask <<= 1; } diff --git a/src/coll/algorithms/bcast.cpp b/src/coll/algorithms/bcast.cpp index ec5de1a97..59dcf76e0 100644 --- a/src/coll/algorithms/bcast.cpp +++ b/src/coll/algorithms/bcast.cpp @@ -13,222 +13,222 @@ See the License for the specific language governing permissions and limitations under the License. */ - -/* -* -* (C) 2001 by Argonne National Laboratory. -* See COPYRIGHT in top-level directory. 
-*/ - -#include "coll/algorithms/algorithms.hpp" -#include "sched/entry/factory/entry_factory.hpp" - -#define MIN(a, b) std::min(a, b) - -ccl_status_t ccl_coll_build_direct_bcast(ccl_sched* sched, - ccl_buffer buf, - size_t count, - const ccl_datatype& dtype, - size_t root, - ccl_comm* comm) { - LOG_DEBUG("build direct bcast"); - - entry_factory::make_entry(sched, buf, count, dtype, root, comm); - return ccl_status_success; -} - -ccl_status_t ccl_coll_build_naive_bcast(ccl_sched* sched, - ccl_buffer buf, - size_t count, - const ccl_datatype& dtype, - size_t root, - ccl_comm* comm) { - LOG_DEBUG("build naive bcast"); - - ccl_status_t status = ccl_status_success; - - size_t rank = comm->rank(); - size_t comm_size = comm->size(); - size_t idx; - - if (comm_size == 1) - goto fn_exit; - - if (rank == root) { - for (idx = 0; idx < comm_size; idx++) { - if (idx != rank) { - entry_factory::make_entry(sched, buf, count, dtype, idx, comm); - } - } - } - else { - entry_factory::make_entry(sched, buf, count, dtype, root, comm); - } - -fn_exit: - return status; -} - -ccl_status_t ccl_coll_build_scatter_for_bcast(ccl_sched* sched, - ccl_buffer tmp_buf, - size_t root, - size_t nbytes, - ccl_comm* comm) { - LOG_DEBUG("build scatter_for_bcast"); - - ccl_status_t status = ccl_status_success; - int rank, local_root, comm_size, src, dst; - int relative_rank, mask; - int scatter_size, curr_size, recv_size, send_size; - - comm_size = comm->size(); - rank = comm->rank(); - local_root = static_cast(root); - relative_rank = (rank >= local_root) ? rank - local_root : rank - local_root + comm_size; - - /* The scatter algorithm divides the buffer into nprocs pieces and - * scatters them among the processes. Root gets the first piece, - * root+1 gets the second piece, and so forth. Uses the same - * binomial tree algorithm as above. Ceiling division is used to - * compute the size of each piece. This means some processes may - * not get any data. 
For example if bufsize = 97 and nprocs = 16, - * ranks 15 and 16 will get 0 data. On each process, the scattered - * data is stored at the same offset in the buffer as it is on the - * root process. */ - - scatter_size = (nbytes + comm_size - 1) / comm_size; /* ceiling division */ - curr_size = (rank == local_root) ? nbytes : 0; /* root starts with all the data */ - - mask = 0x1; - while (mask < comm_size) { - if (relative_rank & mask) { - src = rank - mask; - if (src < 0) - src += comm_size; - - /* compute the exact recv_size to avoid writing this NBC - * in callback style */ - recv_size = nbytes - (relative_rank * scatter_size); - if (recv_size < 0) - recv_size = 0; - - curr_size = recv_size; - - if (recv_size > 0) { - entry_factory::make_entry(sched, - tmp_buf + relative_rank * scatter_size, - recv_size, - ccl_datatype_char, - src, - comm); - sched->add_barrier(); - } - break; - } - mask <<= 1; - } - - /* This process is responsible for all processes that have bits - * set from the LSB upto (but not including) mask. Because of the - * "not including", we start by shifting mask back down one. 
*/ - - mask >>= 1; - while (mask > 0) { - if (relative_rank + mask < comm_size) { - send_size = curr_size - scatter_size * mask; - - /* mask is also the size of this process's subtree */ - - if (send_size > 0) { - dst = rank + mask; - if (dst >= comm_size) - dst -= comm_size; - - entry_factory::make_entry( - sched, - tmp_buf + scatter_size * (relative_rank + mask), - send_size, - ccl_datatype_char, - dst, - comm); - sched->add_barrier(); - curr_size -= send_size; - } - } - mask >>= 1; - } - - return status; -} - -ccl_status_t ccl_coll_build_scatter_ring_allgather_bcast(ccl_sched* sched, - ccl_buffer buf, - size_t count, - const ccl_datatype& dtype, - size_t root, - ccl_comm* comm) { - LOG_DEBUG("build scatter_ring_allgather bcast"); - - ccl_status_t status = ccl_status_success; - - int comm_size, rank, nbytes; - int scatter_size, curr_size; - int i, j, jnext, left, right; - size_t dtype_size = dtype.size(); - - comm_size = comm->size(); - rank = comm->rank(); - - ccl_buffer tmp_buf(buf); - - /* If there is only one process, return */ - if (comm_size == 1) - goto fn_exit; - - nbytes = dtype_size * count; - - CCL_CALL(ccl_coll_build_scatter_for_bcast(sched, tmp_buf, root, nbytes, comm)); - - /* this is the block size used for the scatter operation */ - scatter_size = (nbytes + comm_size - 1) / comm_size; /* ceiling division */ - - /* curr_size is the amount of data that this process now has stored in - * buffer at byte offset (rank*scatter_size) */ - curr_size = MIN(scatter_size, (nbytes - (rank * scatter_size))); - if (curr_size < 0) - curr_size = 0; - - /* long-message allgather or medium-size but non-power-of-two. use ring algorithm. 
*/ - - left = (comm_size + rank - 1) % comm_size; - right = (rank + 1) % comm_size; - - j = rank; - jnext = left; - for (i = 1; i < comm_size; i++) { - int left_count, right_count, left_disp, right_disp, rel_j, rel_jnext; - - rel_j = (j - root + comm_size) % comm_size; - rel_jnext = (jnext - root + comm_size) % comm_size; - left_count = MIN(scatter_size, (nbytes - rel_jnext * scatter_size)); - if (left_count < 0) - left_count = 0; - left_disp = rel_jnext * scatter_size; - right_count = MIN(scatter_size, (nbytes - rel_j * scatter_size)); - if (right_count < 0) - right_count = 0; - right_disp = rel_j * scatter_size; - entry_factory::make_entry( - sched, tmp_buf + right_disp, right_count, ccl_datatype_char, right, comm); - /* sendrecv, no barrier here */ - entry_factory::make_entry( - sched, tmp_buf + left_disp, left_count, ccl_datatype_char, left, comm); - sched->add_barrier(); - - j = jnext; - jnext = (comm_size + jnext - 1) % comm_size; - } - -fn_exit: - return status; -} + +/* +* +* (C) 2001 by Argonne National Laboratory. +* See COPYRIGHT in top-level directory. 
+*/ + +#include "coll/algorithms/algorithms.hpp" +#include "sched/entry/factory/entry_factory.hpp" + +#define MIN(a, b) std::min(a, b) + +ccl::status ccl_coll_build_direct_bcast(ccl_sched* sched, + ccl_buffer buf, + size_t count, + const ccl_datatype& dtype, + int root, + ccl_comm* comm) { + LOG_DEBUG("build direct bcast"); + + entry_factory::make_entry(sched, buf, count, dtype, root, comm); + return ccl::status::success; +} + +ccl::status ccl_coll_build_naive_bcast(ccl_sched* sched, + ccl_buffer buf, + size_t count, + const ccl_datatype& dtype, + int root, + ccl_comm* comm) { + LOG_DEBUG("build naive bcast"); + + ccl::status status = ccl::status::success; + + int rank = comm->rank(); + int comm_size = comm->size(); + int idx; + + if (comm_size == 1) + goto fn_exit; + + if (rank == root) { + for (idx = 0; idx < comm_size; idx++) { + if (idx != rank) { + entry_factory::make_entry(sched, buf, count, dtype, idx, comm); + } + } + } + else { + entry_factory::make_entry(sched, buf, count, dtype, root, comm); + } + +fn_exit: + return status; +} + +ccl::status ccl_coll_build_scatter_for_bcast(ccl_sched* sched, + ccl_buffer tmp_buf, + int root, + size_t nbytes, + ccl_comm* comm) { + LOG_DEBUG("build scatter_for_bcast"); + + ccl::status status = ccl::status::success; + int rank, local_root, comm_size, src, dst; + int relative_rank, mask; + int scatter_size, curr_size, recv_size, send_size; + + comm_size = comm->size(); + rank = comm->rank(); + local_root = static_cast(root); + relative_rank = (rank >= local_root) ? rank - local_root : rank - local_root + comm_size; + + /* The scatter algorithm divides the buffer into nprocs pieces and + * scatters them among the processes. Root gets the first piece, + * root+1 gets the second piece, and so forth. Uses the same + * binomial tree algorithm as above. Ceiling division is used to + * compute the size of each piece. This means some processes may + * not get any data. 
For example if bufsize = 97 and nprocs = 16, + * ranks 15 and 16 will get 0 data. On each process, the scattered + * data is stored at the same offset in the buffer as it is on the + * root process. */ + + scatter_size = (nbytes + comm_size - 1) / comm_size; /* ceiling division */ + curr_size = (rank == local_root) ? nbytes : 0; /* root starts with all the data */ + + mask = 0x1; + while (mask < comm_size) { + if (relative_rank & mask) { + src = rank - mask; + if (src < 0) + src += comm_size; + + /* compute the exact recv_size to avoid writing this NBC + * in callback style */ + recv_size = nbytes - (relative_rank * scatter_size); + if (recv_size < 0) + recv_size = 0; + + curr_size = recv_size; + + if (recv_size > 0) { + entry_factory::make_entry(sched, + tmp_buf + relative_rank * scatter_size, + recv_size, + ccl_datatype_int8, + src, + comm); + sched->add_barrier(); + } + break; + } + mask <<= 1; + } + + /* This process is responsible for all processes that have bits + * set from the LSB upto (but not including) mask. Because of the + * "not including", we start by shifting mask back down one. 
*/ + + mask >>= 1; + while (mask > 0) { + if (relative_rank + mask < comm_size) { + send_size = curr_size - scatter_size * mask; + + /* mask is also the size of this process's subtree */ + + if (send_size > 0) { + dst = rank + mask; + if (dst >= comm_size) + dst -= comm_size; + + entry_factory::make_entry( + sched, + tmp_buf + scatter_size * (relative_rank + mask), + send_size, + ccl_datatype_int8, + dst, + comm); + sched->add_barrier(); + curr_size -= send_size; + } + } + mask >>= 1; + } + + return status; +} + +ccl::status ccl_coll_build_scatter_ring_allgather_bcast(ccl_sched* sched, + ccl_buffer buf, + size_t count, + const ccl_datatype& dtype, + int root, + ccl_comm* comm) { + LOG_DEBUG("build scatter_ring_allgather bcast"); + + ccl::status status = ccl::status::success; + + int comm_size, rank, nbytes; + int scatter_size, curr_size; + int i, j, jnext, left, right; + size_t dtype_size = dtype.size(); + + comm_size = comm->size(); + rank = comm->rank(); + + ccl_buffer tmp_buf(buf); + + /* If there is only one process, return */ + if (comm_size == 1) + goto fn_exit; + + nbytes = dtype_size * count; + + CCL_CALL(ccl_coll_build_scatter_for_bcast(sched, tmp_buf, root, nbytes, comm)); + + /* this is the block size used for the scatter operation */ + scatter_size = (nbytes + comm_size - 1) / comm_size; /* ceiling division */ + + /* curr_size is the amount of data that this process now has stored in + * buffer at byte offset (rank*scatter_size) */ + curr_size = MIN(scatter_size, (nbytes - (rank * scatter_size))); + if (curr_size < 0) + curr_size = 0; + + /* long-message allgather or medium-size but non-power-of-two. use ring algorithm. 
*/ + + left = (comm_size + rank - 1) % comm_size; + right = (rank + 1) % comm_size; + + j = rank; + jnext = left; + for (i = 1; i < comm_size; i++) { + int left_count, right_count, left_disp, right_disp, rel_j, rel_jnext; + + rel_j = (j - root + comm_size) % comm_size; + rel_jnext = (jnext - root + comm_size) % comm_size; + left_count = MIN(scatter_size, (nbytes - rel_jnext * scatter_size)); + if (left_count < 0) + left_count = 0; + left_disp = rel_jnext * scatter_size; + right_count = MIN(scatter_size, (nbytes - rel_j * scatter_size)); + if (right_count < 0) + right_count = 0; + right_disp = rel_j * scatter_size; + entry_factory::make_entry( + sched, tmp_buf + right_disp, right_count, ccl_datatype_int8, right, comm); + /* sendrecv, no barrier here */ + entry_factory::make_entry( + sched, tmp_buf + left_disp, left_count, ccl_datatype_int8, left, comm); + sched->add_barrier(); + + j = jnext; + jnext = (comm_size + jnext - 1) % comm_size; + } + +fn_exit: + return status; +} diff --git a/src/coll/algorithms/double_tree_ops.cpp b/src/coll/algorithms/double_tree_ops.cpp index a981b867e..e124673f2 100644 --- a/src/coll/algorithms/double_tree_ops.cpp +++ b/src/coll/algorithms/double_tree_ops.cpp @@ -143,16 +143,16 @@ static void reduce_bcast_tree(const ccl_bin_tree& tree, } } -ccl_status_t ccl_coll_build_double_tree_op(ccl_sched* sched, - ccl_coll_type coll_type, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction op, - const ccl_double_tree& dtree, - ccl_comm* comm) { - ccl_status_t status = ccl_status_success; +ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, + ccl_coll_type coll_type, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + const ccl_double_tree& dtree, + ccl_comm* comm) { + ccl::status status = ccl::status::success; LOG_DEBUG("build double tree ", ccl_coll_type_to_str(coll_type)); diff --git a/src/coll/algorithms/reduce.cpp 
b/src/coll/algorithms/reduce.cpp index ee99e7f93..65707afed 100644 --- a/src/coll/algorithms/reduce.cpp +++ b/src/coll/algorithms/reduce.cpp @@ -51,32 +51,32 @@ n.(1+(p-1)/p).gamma */ -ccl_status_t ccl_coll_build_direct_reduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - size_t root, - ccl_comm* comm) { +ccl::status ccl_coll_build_direct_reduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + int root, + ccl_comm* comm) { LOG_DEBUG("build direct reduce"); entry_factory::make_entry( sched, send_buf, recv_buf, count, dtype, reduction, root, comm); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - size_t root, - ccl_comm* comm) { +ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + int root, + ccl_comm* comm) { LOG_DEBUG("build Rabenseifner's reduce"); - ccl_status_t status = ccl_status_success; + ccl::status status = ccl::status::success; int i, j, comm_size, rank, local_root, pof2; int rem, dst, new_rank, new_dst, mask, send_idx, recv_idx, last_idx; @@ -347,17 +347,17 @@ ccl_status_t ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, return status; } -ccl_status_t ccl_coll_build_binomial_reduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - size_t root, - ccl_comm* comm) { +ccl::status ccl_coll_build_binomial_reduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + int root, + 
ccl_comm* comm) { LOG_DEBUG("build binomial reduce"); - ccl_status_t status = ccl_status_success; + ccl::status status = ccl::status::success; int comm_size, rank, local_root; int mask, relrank, source, lroot; diff --git a/src/coll/algorithms/reduce_scatter.cpp b/src/coll/algorithms/reduce_scatter.cpp index 659c9ec7b..4a6774ff7 100644 --- a/src/coll/algorithms/reduce_scatter.cpp +++ b/src/coll/algorithms/reduce_scatter.cpp @@ -23,29 +23,27 @@ #include "coll/algorithms/algorithms.hpp" #include "sched/entry/factory/entry_factory.hpp" -ccl_status_t ccl_coll_build_direct_reduce_scatter(ccl_sched* sched, +ccl::status ccl_coll_build_direct_reduce_scatter(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, size_t recv_count, const ccl_datatype& dtype, ccl::reduction reduction, - ccl_comm* comm) -{ + ccl_comm* comm) { LOG_DEBUG("build direct reduce_scatter"); entry_factory::make_entry( sched, send_buf, recv_buf, recv_count, dtype, reduction, comm); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t ccl_coll_build_ring_reduce_scatter_block(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t recv_count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm) -{ +ccl::status ccl_coll_build_ring_reduce_scatter_block(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t recv_count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm) { CCL_THROW_IF_NOT(sched && send_buf && recv_buf, "incorrect values, sched ", sched, @@ -55,37 +53,31 @@ ccl_status_t ccl_coll_build_ring_reduce_scatter_block(ccl_sched* sched, recv_buf); int inplace = (send_buf == recv_buf) ? 1 : 0; - LOG_DEBUG("build ring reduce_scatter_block: ", - inplace ? "in-place" : "out-of-place"); + LOG_DEBUG("build ring reduce_scatter_block: ", inplace ? 
"in-place" : "out-of-place"); - ccl_status_t status = ccl_status_success; - size_t comm_size, rank, idx; + ccl::status status = ccl::status::success; + int comm_size, rank, idx; size_t dtype_size = dtype.size(); - size_t src, dst; + int src, dst; comm_size = comm->size(); rank = comm->rank(); if (recv_count == 0) { - return ccl_status_success; + return ccl::status::success; } if (!inplace) { /* copy local data into recv_buf */ entry_factory::make_entry( - sched, - send_buf + rank * recv_count * dtype_size, - recv_buf, - recv_count, - dtype); + sched, send_buf + rank * recv_count * dtype_size, recv_buf, recv_count, dtype); } /* allocate temporary buffer to store incoming data */ ccl_buffer tmp_buf = sched->alloc_buffer(recv_count * dtype_size); for (idx = 1; idx < comm_size; idx++) { - src = (comm_size + rank - idx) % comm_size; dst = (rank + idx) % comm_size; @@ -93,59 +85,31 @@ ccl_status_t ccl_coll_build_ring_reduce_scatter_block(ccl_sched* sched, * needs from src into tmp_recvbuf */ if (!inplace) { entry_factory::make_entry( - sched, - send_buf + dst * recv_count * dtype_size, - recv_count, - dtype, - dst, - comm); - - entry_factory::make_entry( - sched, - tmp_buf, - recv_count, - dtype, - src, - comm); + sched, send_buf + dst * recv_count * dtype_size, recv_count, dtype, dst, comm); + + entry_factory::make_entry(sched, tmp_buf, recv_count, dtype, src, comm); } else { entry_factory::make_entry( - sched, - recv_buf + dst * recv_count * dtype_size, - recv_count, - dtype, - dst, - comm); - - entry_factory::make_entry( - sched, - tmp_buf, - recv_count, - dtype, - src, - comm); + sched, recv_buf + dst * recv_count * dtype_size, recv_count, dtype, dst, comm); + + entry_factory::make_entry(sched, tmp_buf, recv_count, dtype, src, comm); } sched->add_barrier(); if (!inplace) { entry_factory::make_entry( - sched, - tmp_buf, - recv_count, - recv_buf, - nullptr, - dtype, - op); - } else { - entry_factory::make_entry( - sched, - tmp_buf, - recv_count, - recv_buf + rank * 
recv_count * dtype_size, - nullptr, - dtype, - op); + sched, tmp_buf, recv_count, recv_buf, nullptr, dtype, op); + } + else { + entry_factory::make_entry(sched, + tmp_buf, + recv_count, + recv_buf + rank * recv_count * dtype_size, + nullptr, + dtype, + op); } } @@ -153,25 +117,20 @@ ccl_status_t ccl_coll_build_ring_reduce_scatter_block(ccl_sched* sched, * recv_buf. already done for rank 0 */ if (inplace && (rank != 0)) { entry_factory::make_entry( - sched, - recv_buf + rank * recv_count * dtype_size, - recv_buf, - recv_count, - dtype); + sched, recv_buf + rank * recv_count * dtype_size, recv_buf, recv_count, dtype); } return status; } /* behaves like reduce_scatter_block but last block may contain more elements */ -ccl_status_t ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t send_count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm) { - +ccl::status ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t send_count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm) { LOG_DEBUG("build ring reduce_scatter"); CCL_THROW_IF_NOT(sched && send_buf && recv_buf, @@ -182,23 +141,24 @@ ccl_status_t ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, " recv ", recv_buf); - ccl_status_t status = ccl_status_success; - size_t comm_size, rank; + ccl::status status = ccl::status::success; + int comm_size, rank; size_t dtype_size = dtype.size(); comm_size = comm->size(); rank = comm->rank(); - size_t src = (comm_size + rank - 1) % comm_size; - size_t dst = (comm_size + rank + 1) % comm_size; + int src = (comm_size + rank - 1) % comm_size; + int dst = (comm_size + rank + 1) % comm_size; size_t count = send_count; size_t bytes = count * dtype_size; - size_t chunk_count = (bytes >= ccl::global_data::env().rs_min_chunk_size && - count >= ccl::global_data::env().rs_chunk_count && count >= comm_size) - ? 
ccl::global_data::env().rs_chunk_count - : 1; + size_t chunk_count = + (bytes >= ccl::global_data::env().rs_min_chunk_size && + count >= ccl::global_data::env().rs_chunk_count && (int)count >= comm_size) + ? ccl::global_data::env().rs_chunk_count + : 1; while ((chunk_count > 1) && (bytes / (comm_size * chunk_count) < ccl::global_data::env().rs_min_chunk_size)) { @@ -221,7 +181,7 @@ ccl_status_t ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype); sched->add_barrier(); } - return ccl_status_success; + return ccl::status::success; } ccl_buffer tmp_buf; @@ -238,10 +198,10 @@ ccl_status_t ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, /* the final reduction result on last iteration in corresponsing block */ /* block = group of ~ equal-sized chunks */ - size_t block_idx = (rank + comm_size - 1) % comm_size; + int block_idx = (rank + comm_size - 1) % comm_size; size_t main_block_size = count / comm_size; size_t last_block_size = main_block_size + count % comm_size; - size_t send_block_idx, recv_block_idx; + int send_block_idx, recv_block_idx; size_t send_block_size, recv_block_size; size_t send_block_offset, recv_block_offset; @@ -257,7 +217,7 @@ ccl_status_t ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, ccl_recv_reduce_result_buf_type recv_reduce_result_type; - for (size_t idx = 0; idx < (comm_size - 1); idx++) { + for (int idx = 0; idx < (comm_size - 1); idx++) { send_block_idx = block_idx; recv_block_idx = (comm_size + block_idx - 1) % comm_size; diff --git a/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp b/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp index 104d98cba..2daba7276 100644 --- a/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp +++ b/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "oneapi/ccl/ccl_type_traits.hpp" +#include "oneapi/ccl/type_traits.hpp" #include "coll/algorithms/sparse_allreduce/sparse_handler.hpp" #include "sched/entry/factory/entry_factory.hpp" @@ -74,7 +74,7 @@ break; \ default: \ CCL_FATAL("unexpected sparse_allreduce_algo ", ccl_coll_algorithm_to_str(algo)); \ - return ccl_status_invalid_arguments; \ + return ccl::status::invalid_arguments; \ } \ } while (0) @@ -84,25 +84,14 @@ case ccl::datatype::float32: \ CCL_SPARSE_ALLREDUCE_SELECT_ALGO(itype, float, algo); \ break; \ - case ccl::datatype::float64: \ - CCL_SPARSE_ALLREDUCE_SELECT_ALGO(itype, double, algo); \ - break; \ - case ccl::datatype::int8: CCL_SPARSE_ALLREDUCE_SELECT_ALGO(itype, char, algo); break; \ - case ccl::datatype::int32: CCL_SPARSE_ALLREDUCE_SELECT_ALGO(itype, int, algo); break; \ - case ccl::datatype::int64: \ - CCL_SPARSE_ALLREDUCE_SELECT_ALGO(itype, int64_t, algo); \ - break; \ - case ccl::datatype::uint64: \ - CCL_SPARSE_ALLREDUCE_SELECT_ALGO(itype, uint64_t, algo); \ - break; \ case ccl::datatype::bfloat16: \ - CCL_SPARSE_ALLREDUCE_SELECT_ALGO(itype, ccl::bf16, algo); \ + CCL_SPARSE_ALLREDUCE_SELECT_ALGO(itype, ccl::bfloat16, algo); \ break; \ default: \ CCL_FATAL("value datatype ", \ ccl::global_data::get().dtypes->name(vtype), \ " is not supported yet"); \ - return ccl_status_invalid_arguments; \ + return ccl::status::invalid_arguments; \ } \ } while (0) @@ -141,7 +130,7 @@ sa_handler->size_per_rank = \ static_cast(sched->alloc_buffer(sizeof(size_t) * comm_size).get_ptr()); \ \ - for (size_t i = 0; i < comm_size; i++) \ + for (int i = 0; i < comm_size; i++) \ sa_handler->size_per_rank[i] = sizeof(size_t); \ \ sa_handler->send_ibuf = send_ind_buf.get_ptr(); \ @@ -172,7 +161,7 @@ param_nnz.recv_buf = ccl_buffer(sa_handler->recv_counts, sizeof(size_t) * comm_size); \ param_nnz.send_count = sizeof(size_t); \ param_nnz.recv_counts = sa_handler->size_per_rank; \ - param_nnz.dtype = ccl_datatype_char; \ + param_nnz.dtype = 
ccl_datatype_int8; \ param_nnz.comm = comm; \ \ entry_factory::make_entry(sched, param_nnz); \ @@ -180,7 +169,7 @@ } while (0) template -typename std::enable_if::value, vtype>::type get_mask( +typename std::enable_if::value, vtype>::type get_mask( ccl::reduction op) { switch (op) { case ccl::reduction::sum: return 0; @@ -189,22 +178,22 @@ typename std::enable_if::value, vtype>::type get case ccl::reduction::max: return std::numeric_limits::min(); case ccl::reduction::custom: CCL_FATAL("custom reduction is not supported for sparse_allreduce/mask algorithm"); - return ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; default: return 0; } } template -typename std::enable_if::value, vtype>::type get_mask( +typename std::enable_if::value, vtype>::type get_mask( ccl::reduction op) { switch (op) { - case ccl::reduction::sum: return 0; - case ccl::reduction::prod: return CCL_BF16_ONE; - case ccl::reduction::min: return CCL_BF16_MAX; - case ccl::reduction::max: return CCL_BF16_MIN; + case ccl::reduction::sum: return ccl::bfloat16(0); + case ccl::reduction::prod: return ccl::bfloat16(CCL_BF16_ONE); + case ccl::reduction::min: return ccl::bfloat16(CCL_BF16_MAX); + case ccl::reduction::max: return ccl::bfloat16(CCL_BF16_MIN); case ccl::reduction::custom: CCL_FATAL("custom reduction is not supported for sparse_allreduce/mask algorithm"); - return ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; default: return 0; } } @@ -300,7 +289,7 @@ void sparse_coalesce(ccl_sparse_allreduce_handler* sah) { } template -ccl_status_t sparse_reduce_ring(const void* ctx) { +ccl::status sparse_reduce_ring(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; /* Having received the msg we should prepare it for further send operation to the next neighbour. 
@@ -371,11 +360,11 @@ ccl_status_t sparse_reduce_ring(const void* ctx) { ccl_comp_copy(snd_i, buf_i.data(), sa_handler->itype_size * sa_handler->dst_count[0], - ccl_datatype_char); + ccl_datatype_int8); ccl_comp_copy(snd_v, buf_v.data(), sa_handler->vtype_size * sa_handler->dst_count[1], - ccl_datatype_char); + ccl_datatype_int8); size_t idx_offset = 0; for (auto id : unique_indices_ids) { @@ -409,13 +398,13 @@ ccl_status_t sparse_reduce_ring(const void* ctx) { ccl_comp_copy(buf_i.data(), (i_type*)(sa_handler->dst_buf), sa_handler->itype_size * merge_idx_len, - ccl_datatype_char); + ccl_datatype_int8); ccl_comp_copy( buf_v.data(), (v_type*)((char*)(sa_handler->dst_buf) + sa_handler->itype_size * merge_idx_len), sa_handler->vtype_size * merge_idx_len * sa_handler->val_dim_cnt, - ccl_datatype_char); + ccl_datatype_int8); sa_handler->dst_count[0] = merge_idx_len; sa_handler->dst_count[1] = merge_idx_len * sa_handler->val_dim_cnt; @@ -425,15 +414,15 @@ ccl_status_t sparse_reduce_ring(const void* ctx) { ccl_comp_copy(sa_handler->recv_buf, sa_handler->send_tmp_buf, idx_size + sa_handler->send_count[1] * sa_handler->vtype_size, - ccl_datatype_char); + ccl_datatype_int8); sa_handler->iter++; - return ccl_status_success; + return ccl::status::success; } template -ccl_status_t sparse_prepare_result_ring(const void* ctx) { +ccl::status sparse_prepare_result_ring(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; /* data should be returned as sorted in the result buffer */ @@ -458,25 +447,25 @@ ccl_status_t sparse_prepare_result_ring(const void* ctx) { sa_handler->iv_map->clear(); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_send_count_ring(const void* ctx, void* field_ptr) { +ccl::status sparse_get_send_count_ring(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; size_t* cnt_ptr = (size_t*)field_ptr; *cnt_ptr = 
sa_handler->send_count[0] * (sa_handler->itype_size + sa_handler->val_dim_cnt * sa_handler->vtype_size); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_send_buf_ring(const void* ctx, void* field_ptr) { +ccl::status sparse_get_send_buf_ring(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; buf_ptr->set(sa_handler->send_tmp_buf); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_recv_count_ring(const void* ctx, void* field_ptr) { +ccl::status sparse_get_recv_count_ring(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; size_t* cnt_ptr = (size_t*)field_ptr; @@ -485,21 +474,21 @@ ccl_status_t sparse_get_recv_count_ring(const void* ctx, void* field_ptr) { sa_handler->comm_size]; *cnt_ptr = nnz * (sa_handler->itype_size + sa_handler->val_dim_cnt * sa_handler->vtype_size); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_recv_buf_ring(const void* ctx, void* field_ptr) { +ccl::status sparse_get_recv_buf_ring(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; buf_ptr->set(sa_handler->recv_buf); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_set_max_buf_size_ring(const void* ctx) { +ccl::status sparse_set_max_buf_size_ring(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; size_t max_nnz = sa_handler->recv_counts[0]; - for (size_t i = 1; i < sa_handler->comm_size; i++) { + for (int i = 1; i < sa_handler->comm_size; i++) { if (max_nnz < sa_handler->recv_counts[i]) { max_nnz = sa_handler->recv_counts[i]; } @@ -514,11 +503,11 @@ ccl_status_t sparse_set_max_buf_size_ring(const void* ctx) { 
sa_handler->send_tmp_buf, sa_handler->dst_buf, sa_handler->dst_count[0] * common_size_part); sa_handler->recv_buf = sa_handler->sched->alloc_buffer(max_size).get_ptr(); - return ccl_status_success; + return ccl::status::success; } template -ccl_status_t sparse_coalesce_ring(const void* ctx) { +ccl::status sparse_coalesce_ring(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; sparse_coalesce(sa_handler); @@ -530,27 +519,27 @@ ccl_status_t sparse_coalesce_ring(const void* ctx) { CCL_MEMCPY(&sa_handler->dst_count, &sa_handler->send_count, sizeof(size_t) * 2); CCL_SPARSE_ALLREDUCE_IF_SINGLE_RANK(); - return ccl_status_success; + return ccl::status::success; } template -ccl_status_t ccl_coll_build_sparse_allreduce_ring(ccl_sched* sched, - ccl_buffer send_ind_buf, - size_t send_ind_count, - ccl_buffer send_val_buf, - size_t send_val_count, - void** recv_ind_buf, - size_t* recv_ind_count, - void** recv_val_buf, - size_t* recv_val_count, - const ccl_datatype& index_dtype, - const ccl_datatype& value_dtype, - ccl::reduction op, - ccl_comm* comm) { - ccl_status_t status = ccl_status_success; - - size_t comm_size = comm->size(); - size_t rank = comm->rank(); +ccl::status ccl_coll_build_sparse_allreduce_ring(ccl_sched* sched, + ccl_buffer send_ind_buf, + size_t send_ind_count, + ccl_buffer send_val_buf, + size_t send_val_count, + void** recv_ind_buf, + size_t* recv_ind_count, + void** recv_val_buf, + size_t* recv_val_count, + const ccl_datatype& index_dtype, + const ccl_datatype& value_dtype, + ccl::reduction op, + ccl_comm* comm) { + ccl::status status = ccl::status::success; + + int comm_size = comm->size(); + int rank = comm->rank(); /* get data type sizes */ size_t vtype_size = sizeof(v_type); @@ -570,10 +559,10 @@ ccl_status_t ccl_coll_build_sparse_allreduce_ring(ccl_sched* sched, /* send from left to right (ring)*/ /* receive from the left neighbour */ - size_t recv_from = (rank - 1 + comm_size) % comm_size; + int recv_from 
= (rank - 1 + comm_size) % comm_size; /* send to the right neighbour */ - size_t send_to = (rank + 1) % comm_size; + int send_to = (rank + 1) % comm_size; sa_handler->recv_from = recv_from; sa_handler->iter = 0; @@ -591,16 +580,16 @@ ccl_status_t ccl_coll_build_sparse_allreduce_ring(ccl_sched* sched, entry_factory::make_entry(sched, sparse_set_max_buf_size_ring, sa_handler); sched->add_barrier(); - for (size_t i = 0; i < comm_size - 1; i++) { + for (int i = 0; i < comm_size - 1; i++) { /* send local data to the right neighbour */ send_entry* se = entry_factory::make_entry( - sched, ccl_buffer(), 0, ccl_datatype_char, send_to, comm); + sched, ccl_buffer(), 0, ccl_datatype_int8, send_to, comm); se->set_field_fn(sparse_get_send_buf_ring, sa_handler); se->set_field_fn(sparse_get_send_count_ring, sa_handler); /* receive data from the left neighbour */ recv_entry* re = entry_factory::make_entry( - sched, ccl_buffer(), 0, ccl_datatype_char, recv_from, comm); + sched, ccl_buffer(), 0, ccl_datatype_int8, recv_from, comm); re->set_field_fn(sparse_get_recv_buf_ring, sa_handler); re->set_field_fn(sparse_get_recv_count_ring, sa_handler); sched->add_barrier(); @@ -621,7 +610,7 @@ ccl_status_t ccl_coll_build_sparse_allreduce_ring(ccl_sched* sched, } template -ccl_status_t sparse_create_matrix_mask(const void* ctx) { +ccl::status sparse_create_matrix_mask(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; LOG_TRACE("sa_handler: ", sa_handler, @@ -674,7 +663,7 @@ ccl_status_t sparse_create_matrix_mask(const void* ctx) { ccl_comp_copy(matrix, (char*)sa_handler->dst_buf + idx_cnt * sa_handler->itype_size, matrix_size, - ccl_datatype_char); + ccl_datatype_int8); CCL_FREE(matrix); sa_handler->iv_map->clear(); @@ -686,27 +675,27 @@ ccl_status_t sparse_create_matrix_mask(const void* ctx) { *sa_handler->recv_ibuf = sa_handler->dst_buf; *sa_handler->recv_vbuf = ((char*)sa_handler->dst_buf + sa_handler->itype_size * idx_cnt); - return 
ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_allreduce_buf_mask(const void* ctx, void* field_ptr) { +ccl::status sparse_get_allreduce_buf_mask(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; buf_ptr->set(*sa_handler->recv_vbuf); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_allreduce_count_mask(const void* ctx, void* field_ptr) { +ccl::status sparse_get_allreduce_count_mask(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; size_t* cnt_ptr = (size_t*)field_ptr; *cnt_ptr = *sa_handler->recv_vcount; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_nnz_per_rank_mask(const void* ctx) { +ccl::status sparse_nnz_per_rank_mask(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; sa_handler->recv_buf_count = 0; - for (size_t i = 0; i < sa_handler->comm_size; i++) { + for (int i = 0; i < sa_handler->comm_size; i++) { sa_handler->recv_buf_count += sa_handler->recv_counts[i]; } @@ -714,32 +703,32 @@ ccl_status_t sparse_nnz_per_rank_mask(const void* ctx) { sa_handler->sched->alloc_buffer(sa_handler->itype_size * sa_handler->recv_buf_count) .get_ptr(); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_allgatherv_buf_mask(const void* ctx, void* field_ptr) { +ccl::status sparse_get_allgatherv_buf_mask(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; buf_ptr->set(sa_handler->recv_buf); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_send_buf_mask(const void* ctx, void* field_ptr) { +ccl::status sparse_get_send_buf_mask(const void* ctx, void* field_ptr) { 
ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; buf_ptr->set(sa_handler->dst_buf); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_send_count_mask(const void* ctx, void* field_ptr) { +ccl::status sparse_get_send_count_mask(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; size_t* count = (size_t*)field_ptr; *count = sa_handler->dst_count[0]; - return ccl_status_success; + return ccl::status::success; } template -ccl_status_t sparse_coalesce_mask(const void* ctx) { +ccl::status sparse_coalesce_mask(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; sparse_coalesce(sa_handler); @@ -750,26 +739,26 @@ ccl_status_t sparse_coalesce_mask(const void* ctx) { sa_handler->dst_count[1] = iv_map_cnt * sa_handler->val_dim_cnt; CCL_SPARSE_ALLREDUCE_IF_SINGLE_RANK(); - return ccl_status_success; + return ccl::status::success; } template -ccl_status_t ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, - ccl_buffer send_ind_buf, - size_t send_ind_count, - ccl_buffer send_val_buf, - size_t send_val_count, - void** recv_ind_buf, - size_t* recv_ind_count, - void** recv_val_buf, - size_t* recv_val_count, - const ccl_datatype& index_dtype, - const ccl_datatype& value_dtype, - ccl::reduction op, - ccl_comm* comm) { - ccl_status_t status = ccl_status_success; - - size_t comm_size = comm->size(); +ccl::status ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, + ccl_buffer send_ind_buf, + size_t send_ind_count, + ccl_buffer send_val_buf, + size_t send_val_count, + void** recv_ind_buf, + size_t* recv_ind_count, + void** recv_val_buf, + size_t* recv_val_count, + const ccl_datatype& index_dtype, + const ccl_datatype& value_dtype, + ccl::reduction op, + ccl_comm* comm) { + ccl::status status = ccl::status::success; + + int comm_size = comm->size(); /* get data 
type sizes */ size_t itype_size = sizeof(i_type); @@ -840,11 +829,11 @@ ccl_status_t ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, return status; } -ccl_status_t sparse_alloc_result_buf_allgatherv(const void* ctx) { +ccl::status sparse_alloc_result_buf_allgatherv(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; sa_handler->recv_buf_count = 0; - for (size_t i = 0; i < sa_handler->comm_size; i++) { + for (int i = 0; i < sa_handler->comm_size; i++) { sa_handler->recv_buf_count += sa_handler->recv_counts[i]; } @@ -883,21 +872,21 @@ ccl_status_t sparse_alloc_result_buf_allgatherv(const void* ctx) { CCL_THROW_IF_NOT(sa_handler->all_idx_buf && sa_handler->all_val_buf); - return ccl_status_success; + return ccl::status::success; } template -ccl_status_t sparse_set_v_counts_allgatherv(const void* ctx) { +ccl::status sparse_set_v_counts_allgatherv(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; size_t stride = stride_per_comm * sa_handler->comm_size; - for (size_t i = 0; i < sa_handler->comm_size; i++) { + for (int i = 0; i < sa_handler->comm_size; i++) { sa_handler->recv_counts[i + stride] = sa_handler->recv_counts[i] * sa_handler->val_dim_cnt; } - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_return_gathered_allgatherv(const void* ctx) { +ccl::status sparse_return_gathered_allgatherv(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; *sa_handler->recv_icount = sa_handler->recv_buf_count; *sa_handler->recv_vcount = sa_handler->recv_buf_count * sa_handler->val_dim_cnt; @@ -905,11 +894,11 @@ ccl_status_t sparse_return_gathered_allgatherv(const void* ctx) { *sa_handler->recv_ibuf = sa_handler->all_idx_buf; *sa_handler->recv_vbuf = sa_handler->all_val_buf; - return ccl_status_success; + return ccl::status::success; } template -ccl_status_t sparse_reduce_gathered_allgatherv(const void* ctx) { 
+ccl::status sparse_reduce_gathered_allgatherv(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; i_type* indices = static_cast(sa_handler->all_idx_buf); v_type* values = static_cast(sa_handler->all_val_buf); @@ -997,39 +986,39 @@ ccl_status_t sparse_reduce_gathered_allgatherv(const void* ctx) { *sa_handler->recv_ibuf = i_recv; *sa_handler->recv_vbuf = v_recv; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_i_recv_allgatherv(const void* ctx, void* field_ptr) { +ccl::status sparse_get_i_recv_allgatherv(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; buf_ptr->set(sa_handler->all_idx_buf); - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_i_send_allgatherv(const void* ctx, void* field_ptr) { +ccl::status sparse_get_i_send_allgatherv(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; buf_ptr->set(sa_handler->dst_ibuf); - return ccl_status_success; + return ccl::status::success; } template -ccl_status_t sparse_get_send_count_allgatherv(const void* ctx, void* field_ptr) { +ccl::status sparse_get_send_count_allgatherv(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; size_t* send_buf_count = (size_t*)field_ptr; *send_buf_count = sa_handler->send_count[send_count_src_index]; - return ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_v_recv_allgatherv(const void* ctx, void* field_ptr) { +ccl::status sparse_get_v_recv_allgatherv(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; buf_ptr->set(sa_handler->all_val_buf); - return 
ccl_status_success; + return ccl::status::success; } -ccl_status_t sparse_get_v_send_allgatherv(const void* ctx, void* field_ptr) { +ccl::status sparse_get_v_send_allgatherv(const void* ctx, void* field_ptr) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; ccl_buffer* buf_ptr = (ccl_buffer*)field_ptr; if (sa_handler->sched->coll_attr.sparse_coalesce_mode == ccl::sparse_coalesce_mode::disable) { @@ -1039,11 +1028,11 @@ ccl_status_t sparse_get_v_send_allgatherv(const void* ctx, void* field_ptr) { buf_ptr->set(sa_handler->dst_vbuf); } - return ccl_status_success; + return ccl::status::success; } template -ccl_status_t sparse_coalesce_allgatherv(const void* ctx) { +ccl::status sparse_coalesce_allgatherv(const void* ctx) { ccl_sparse_allreduce_handler* sa_handler = (ccl_sparse_allreduce_handler*)ctx; sparse_coalesce(sa_handler); @@ -1060,26 +1049,26 @@ ccl_status_t sparse_coalesce_allgatherv(const void* ctx) { *sa_handler->recv_vbuf = sa_handler->dst_vbuf; } - return ccl_status_success; + return ccl::status::success; } template -ccl_status_t ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, - ccl_buffer send_ind_buf, - size_t send_ind_count, - ccl_buffer send_val_buf, - size_t send_val_count, - void** recv_ind_buf, - size_t* recv_ind_count, - void** recv_val_buf, - size_t* recv_val_count, - const ccl_datatype& index_dtype, - const ccl_datatype& value_dtype, - ccl::reduction op, - ccl_comm* comm) { - ccl_status_t status = ccl_status_success; - - size_t comm_size = comm->size(); +ccl::status ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, + ccl_buffer send_ind_buf, + size_t send_ind_count, + ccl_buffer send_val_buf, + size_t send_val_count, + void** recv_ind_buf, + size_t* recv_ind_count, + void** recv_val_buf, + size_t* recv_val_count, + const ccl_datatype& index_dtype, + const ccl_datatype& value_dtype, + ccl::reduction op, + ccl_comm* comm) { + ccl::status status = ccl::status::success; + + int comm_size = 
comm->size(); /* get data type sizes */ size_t vtype_size = sizeof(v_type); diff --git a/src/coll/algorithms/sparse_allreduce/sparse_handler.hpp b/src/coll/algorithms/sparse_allreduce/sparse_handler.hpp index a37bcc04b..26e972491 100644 --- a/src/coll/algorithms/sparse_allreduce/sparse_handler.hpp +++ b/src/coll/algorithms/sparse_allreduce/sparse_handler.hpp @@ -24,9 +24,9 @@ struct ccl_sparse_allreduce_handler { size_t recv_buf_count; size_t itype_size; size_t vtype_size; - size_t comm_size; + int comm_size; size_t buf_size; - size_t recv_from; + int recv_from; size_t iter; /*iteration within ring algorithm*/ size_t send_count[2]; diff --git a/src/coll/ccl_allgather_op_attr.hpp b/src/coll/ccl_allgather_op_attr.hpp index 6de415fb6..b9c43842b 100644 --- a/src/coll/ccl_allgather_op_attr.hpp +++ b/src/coll/ccl_allgather_op_attr.hpp @@ -14,10 +14,10 @@ limitations under the License. */ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" #include "coll/coll_common_attributes.hpp" namespace ccl { @@ -27,9 +27,8 @@ class ccl_allgatherv_attr_impl_t : public ccl_operation_attr_impl_t { ccl_allgatherv_attr_impl_t(const base_t& base); ccl_allgatherv_attr_impl_t( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); ccl_allgatherv_attr_impl_t(const ccl_allgatherv_attr_impl_t& src); private: diff --git a/src/coll/ccl_allreduce_op_attr.hpp b/src/coll/ccl_allreduce_op_attr.hpp index deb72c14a..911861734 100644 --- a/src/coll/ccl_allreduce_op_attr.hpp +++ b/src/coll/ccl_allreduce_op_attr.hpp @@ -14,10 +14,10 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" #include "coll/coll_common_attributes.hpp" namespace ccl { @@ -26,12 +26,11 @@ class ccl_allreduce_attr_impl_t : public ccl_operation_attr_impl_t { using base_t = ccl_operation_attr_impl_t; ccl_allreduce_attr_impl_t( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); using reduction_fn_traits_t = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; typename reduction_fn_traits_t::return_type set_attribute_value( typename reduction_fn_traits_t::type val, const reduction_fn_traits_t& t); diff --git a/src/coll/ccl_alltoall_op_attr.hpp b/src/coll/ccl_alltoall_op_attr.hpp index cf575c8de..d1b58587d 100644 --- a/src/coll/ccl_alltoall_op_attr.hpp +++ b/src/coll/ccl_alltoall_op_attr.hpp @@ -14,10 +14,10 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" #include "coll/coll_common_attributes.hpp" namespace ccl { @@ -26,9 +26,8 @@ class ccl_alltoall_attr_impl_t : public ccl_operation_attr_impl_t { using base_t = ccl_operation_attr_impl_t; ccl_alltoall_attr_impl_t( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); }; } // namespace ccl diff --git a/src/coll/ccl_alltoallv_op_attr.hpp b/src/coll/ccl_alltoallv_op_attr.hpp index 019a2745f..b9a855bb2 100644 --- a/src/coll/ccl_alltoallv_op_attr.hpp +++ b/src/coll/ccl_alltoallv_op_attr.hpp @@ -14,10 +14,10 @@ limitations under the License. */ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" #include "coll/coll_common_attributes.hpp" namespace ccl { @@ -26,9 +26,8 @@ class ccl_alltoallv_attr_impl_t : public ccl_operation_attr_impl_t { using base_t = ccl_operation_attr_impl_t; ccl_alltoallv_attr_impl_t( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); }; } // namespace ccl diff --git a/src/coll/ccl_barrier_attr.hpp b/src/coll/ccl_barrier_attr.hpp index f5657711f..a3c7849a1 100644 --- a/src/coll/ccl_barrier_attr.hpp +++ b/src/coll/ccl_barrier_attr.hpp @@ -14,10 +14,10 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" #include "coll/coll_common_attributes.hpp" namespace ccl { @@ -26,8 +26,7 @@ class ccl_barrier_attr_impl_t : public ccl_operation_attr_impl_t { using base_t = ccl_operation_attr_impl_t; ccl_barrier_attr_impl_t( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); }; } // namespace ccl diff --git a/src/coll/ccl_bcast_op_attr.hpp b/src/coll/ccl_bcast_op_attr.hpp index 1a733a596..78db47ef1 100644 --- a/src/coll/ccl_bcast_op_attr.hpp +++ b/src/coll/ccl_bcast_op_attr.hpp @@ -14,10 +14,10 @@ limitations under the License. */ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" #include "coll/coll_common_attributes.hpp" namespace ccl { @@ -26,9 +26,8 @@ class ccl_broadcast_attr_impl_t : public ccl_operation_attr_impl_t { using base_t = ccl_operation_attr_impl_t; ccl_broadcast_attr_impl_t( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); }; } // namespace ccl diff --git a/src/coll/ccl_reduce_op_attr.hpp b/src/coll/ccl_reduce_op_attr.hpp index 74a1dc767..825367ddb 100644 --- a/src/coll/ccl_reduce_op_attr.hpp +++ b/src/coll/ccl_reduce_op_attr.hpp @@ -14,10 +14,10 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" #include "coll/coll_common_attributes.hpp" namespace ccl { @@ -26,12 +26,11 @@ class ccl_reduce_attr_impl_t : public ccl_operation_attr_impl_t { using base_t = ccl_operation_attr_impl_t; ccl_reduce_attr_impl_t( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); using reduction_fn_traits_t = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; typename reduction_fn_traits_t::return_type set_attribute_value( typename reduction_fn_traits_t::type val, const reduction_fn_traits_t& t); diff --git a/src/coll/ccl_reduce_scatter_op_attr.hpp b/src/coll/ccl_reduce_scatter_op_attr.hpp index 7590dfe46..5f40c76ab 100644 --- a/src/coll/ccl_reduce_scatter_op_attr.hpp +++ b/src/coll/ccl_reduce_scatter_op_attr.hpp @@ -14,10 +14,10 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" #include "coll/coll_common_attributes.hpp" namespace ccl { @@ -27,13 +27,12 @@ class ccl_reduce_scatter_attr_impl_t : public ccl_operation_attr_impl_t { using base_t = ccl_operation_attr_impl_t; ccl_reduce_scatter_attr_impl_t( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); using reduction_fn_traits_t = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; typename reduction_fn_traits_t::return_type set_attribute_value( typename reduction_fn_traits_t::type val, const reduction_fn_traits_t& t); diff --git a/src/coll/ccl_sparse_allreduce_op_attr.hpp b/src/coll/ccl_sparse_allreduce_op_attr.hpp index 4abc53a3e..8ba49dae8 100644 --- a/src/coll/ccl_sparse_allreduce_op_attr.hpp +++ b/src/coll/ccl_sparse_allreduce_op_attr.hpp @@ -14,10 +14,10 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" #include "coll/coll_common_attributes.hpp" namespace ccl { @@ -26,13 +26,12 @@ class ccl_sparse_allreduce_attr_impl_t : public ccl_operation_attr_impl_t { using base_t = ccl_operation_attr_impl_t; ccl_sparse_allreduce_attr_impl_t( - const typename details::ccl_api_type_attr_traits::type& - version); + const typename detail::ccl_api_type_attr_traits::type& version); using sparse_allreduce_completion_fn_traits = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; typename sparse_allreduce_completion_fn_traits::return_type set_attribute_value( typename sparse_allreduce_completion_fn_traits::type val, const sparse_allreduce_completion_fn_traits& t); @@ -40,8 +39,8 @@ class ccl_sparse_allreduce_attr_impl_t : public ccl_operation_attr_impl_t { const sparse_allreduce_completion_fn_traits& id) const; using sparse_allreduce_alloc_fn_traits = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; typename sparse_allreduce_alloc_fn_traits::return_type set_attribute_value( typename sparse_allreduce_alloc_fn_traits::type val, const sparse_allreduce_alloc_fn_traits& t); @@ -49,8 +48,8 @@ class ccl_sparse_allreduce_attr_impl_t : public ccl_operation_attr_impl_t { const sparse_allreduce_alloc_fn_traits& id) const; using sparse_allreduce_fn_ctx_traits = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; typename sparse_allreduce_fn_ctx_traits::return_type set_attribute_value( typename sparse_allreduce_fn_ctx_traits::type val, const sparse_allreduce_fn_ctx_traits& t); @@ -58,8 +57,8 @@ class ccl_sparse_allreduce_attr_impl_t : public ccl_operation_attr_impl_t { const 
sparse_allreduce_fn_ctx_traits& id) const; using sparse_coalesce_mode_traits = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; typename sparse_coalesce_mode_traits::return_type set_attribute_value( typename sparse_coalesce_mode_traits::type val, const sparse_coalesce_mode_traits& t); diff --git a/src/coll/coll.cpp b/src/coll/coll.cpp index 0d7cb03a6..86e202686 100644 --- a/src/coll/coll.cpp +++ b/src/coll/coll.cpp @@ -13,27 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_aliases.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/aliases.hpp" -#include "oneapi/ccl/ccl_type_traits.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" +#include "oneapi/ccl/type_traits.hpp" +#include "oneapi/ccl/types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_coll_attr.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" +#include "oneapi/ccl/coll_attr.hpp" -#include "oneapi/ccl/ccl_comm_split_attr_ids.hpp" -#include "oneapi/ccl/ccl_comm_split_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_comm_split_attr.hpp" +#include "oneapi/ccl/comm_split_attr_ids.hpp" +#include "oneapi/ccl/comm_split_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_split_attr.hpp" -#include "common/event/event_internal/event_internal_attr_ids.hpp" -#include "common/event/event_internal/event_internal_attr_ids_traits.hpp" -#include "common/event/event_internal/event_internal.hpp" - -#include "oneapi/ccl/ccl_stream_attr_ids.hpp" -#include "oneapi/ccl/ccl_stream_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_stream.hpp" +#include "oneapi/ccl/stream_attr_ids.hpp" +#include "oneapi/ccl/stream_attr_ids_traits.hpp" +#include "oneapi/ccl/stream.hpp" #include "common/request/request.hpp" @@ -61,38 +57,13 @@ #include 
"unordered_coll/unordered_coll.hpp" #define COPY_COMMON_OP_ATTRS(from, to) \ - to->prologue_fn = from.get().get(); \ - to->epilogue_fn = from.get().get(); \ + to->prologue_fn = nullptr; /*from.get().get();*/ \ + to->epilogue_fn = nullptr; /*from.get().get();*/ \ to->priority = from.get(); \ to->synchronous = from.get(); \ to->to_cache = from.get(); \ to->match_id = from.get(); -ccl_coll_attr::ccl_coll_attr(const ccl_coll_attr_t* attr) { - *this = attr ?: ccl::global_data::get().default_coll_attr.get(); -} - -ccl_coll_attr& ccl_coll_attr::operator=(const ccl_coll_attr_t* attr) { - prologue_fn = attr->prologue_fn; - epilogue_fn = attr->epilogue_fn; - reduction_fn = attr->reduction_fn; - priority = attr->priority; - synchronous = attr->synchronous; - to_cache = attr->to_cache && attr->match_id && attr->match_id[0]; - vector_buf = attr->vector_buf; - match_id = (attr->match_id ? attr->match_id : ""); - - sparse_allreduce_completion_fn = attr->sparse_allreduce_completion_fn; - sparse_allreduce_alloc_fn = attr->sparse_allreduce_alloc_fn; - sparse_allreduce_fn_ctx = attr->sparse_allreduce_fn_ctx; - sparse_coalesce_mode = attr->sparse_coalesce_mode; - - if (to_cache != attr->to_cache) - LOG_INFO("collective caching is requested but no match_id is provided, disable caching"); - - return *this; -} - //TODO temporary solution for type convertation, ccl_coll_attr would be depreacated ccl_coll_attr::ccl_coll_attr(const ccl::allgatherv_attr& attr) { COPY_COMMON_OP_ATTRS(attr, this); @@ -149,7 +120,8 @@ static ccl_request* ccl_coll_create(ccl_coll_param& param, const ccl_coll_attr& bool postpone_schedule = false; if (ccl::global_data::env().enable_unordered_coll) { if (!attr.match_id.empty()) { - auto comm = param.comm->unordered_coll_manager->get_comm(std::string(attr.match_id)).get(); + auto comm = + param.comm->unordered_coll_manager->get_comm(std::string(attr.match_id)).get(); if (!comm) { if (attr.synchronous) { CCL_THROW("unsupported collective (synchronous && unordered && 
!communicator)"); @@ -212,7 +184,8 @@ static ccl_request* ccl_gpu_coll_create(ccl_coll_param& param, const ccl_coll_at bool postpone_schedule = false; if (ccl::global_data::env().enable_unordered_coll) { if (!attr.match_id.empty()) { - auto comm = param.comm->unordered_coll_manager->get_comm(std::string(attr.match_id)).get(); + auto comm = + param.comm->unordered_coll_manager->get_comm(std::string(attr.match_id)).get(); if (!comm) { if (attr.synchronous) { CCL_THROW("unsupported collective (synchronous && unordered && !communicator)"); @@ -267,14 +240,14 @@ static ccl_request* ccl_gpu_coll_create(ccl_coll_param& param, const ccl_coll_at return request; } -ccl_status_t ccl_coll_build_allgatherv(ccl_sched* sched, - ccl_buffer send_buf, - size_t send_count, - ccl_buffer recv_buf, - const size_t* recv_counts, - const ccl_datatype& dtype, - ccl_comm* comm) { - ccl_status_t status = ccl_status_success; +ccl::status ccl_coll_build_allgatherv(ccl_sched* sched, + ccl_buffer send_buf, + size_t send_count, + ccl_buffer recv_buf, + const size_t* recv_counts, + const ccl_datatype& dtype, + ccl_comm* comm) { + ccl::status status = ccl::status::success; ccl_selector_param param; param.ctype = ccl_coll_allgatherv; @@ -300,19 +273,19 @@ ccl_status_t ccl_coll_build_allgatherv(ccl_sched* sched, break; default: CCL_FATAL("unexpected allgatherv_algo ", ccl_coll_algorithm_to_str(algo)); - return ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; } return status; } -ccl_status_t ccl_coll_build_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - ccl_comm* comm) { - ccl_status_t status = ccl_status_success; +ccl::status ccl_coll_build_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm) { + ccl::status status = ccl::status::success; ccl_selector_param param; 
param.ctype = ccl_coll_allreduce; @@ -360,23 +333,23 @@ ccl_status_t ccl_coll_build_allreduce(ccl_sched* sched, break; case ccl_coll_allreduce_2d: CCL_CALL(comm->allreduce_2d_builder->build( - sched, send_buf, recv_buf, count, dtype, reduction)); + sched, send_buf, recv_buf, count, dtype, reduction)); break; default: CCL_FATAL("unexpected allreduce_algo ", ccl_coll_algorithm_to_str(algo)); - return ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; } return status; } -ccl_status_t ccl_coll_build_alltoall(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl_comm* comm) { - ccl_status_t status = ccl_status_success; +ccl::status ccl_coll_build_alltoall(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl_comm* comm) { + ccl::status status = ccl::status::success; ccl_selector_param param; param.ctype = ccl_coll_alltoall; @@ -392,20 +365,20 @@ ccl_status_t ccl_coll_build_alltoall(ccl_sched* sched, break; default: CCL_FATAL("unexpected alltoall_algo ", ccl_coll_algorithm_to_str(algo)); - return ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; } return status; } -ccl_status_t ccl_coll_build_alltoallv(ccl_sched* sched, - ccl_buffer send_buf, - const size_t* send_counts, - ccl_buffer recv_buf, - const size_t* recv_counts, - const ccl_datatype& dtype, - ccl_comm* comm) { - ccl_status_t status = ccl_status_success; +ccl::status ccl_coll_build_alltoallv(ccl_sched* sched, + ccl_buffer send_buf, + const size_t* send_counts, + ccl_buffer recv_buf, + const size_t* recv_counts, + const ccl_datatype& dtype, + ccl_comm* comm) { + ccl::status status = ccl::status::success; ccl_selector_param param; param.ctype = ccl_coll_alltoallv; @@ -421,19 +394,19 @@ ccl_status_t ccl_coll_build_alltoallv(ccl_sched* sched, break; default: CCL_FATAL("unexpected alltoallv_algo ", ccl_coll_algorithm_to_str(algo)); - return 
ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; } return status; } -ccl_status_t ccl_coll_build_barrier(ccl_sched* sched, ccl_comm* comm) { - ccl_status_t status = ccl_status_success; +ccl::status ccl_coll_build_barrier(ccl_sched* sched, ccl_comm* comm) { + ccl::status status = ccl::status::success; ccl_selector_param param; param.ctype = ccl_coll_barrier; param.count = 0; - param.dtype = ccl_datatype_char; + param.dtype = ccl_datatype_int8; param.comm = comm; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -445,19 +418,19 @@ ccl_status_t ccl_coll_build_barrier(ccl_sched* sched, ccl_comm* comm) { break; default: CCL_FATAL("unexpected barrier_algo ", ccl_coll_algorithm_to_str(algo)); - return ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; } return status; } -ccl_status_t ccl_coll_build_bcast(ccl_sched* sched, - ccl_buffer buf, - size_t count, - const ccl_datatype& dtype, - size_t root, - ccl_comm* comm) { - ccl_status_t status = ccl_status_success; +ccl::status ccl_coll_build_bcast(ccl_sched* sched, + ccl_buffer buf, + size_t count, + const ccl_datatype& dtype, + int root, + ccl_comm* comm) { + ccl::status status = ccl::status::success; ccl_selector_param param; param.ctype = ccl_coll_bcast; @@ -492,20 +465,20 @@ ccl_status_t ccl_coll_build_bcast(ccl_sched* sched, break; default: CCL_FATAL("unexpected bcast_algo ", ccl_coll_algorithm_to_str(algo)); - return ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; } return status; } -ccl_status_t ccl_coll_build_reduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - size_t root, - ccl_comm* comm) { - ccl_status_t status = ccl_status_success; +ccl::status ccl_coll_build_reduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + int root, + ccl_comm* comm) { + 
ccl::status status = ccl::status::success; ccl_selector_param param; param.ctype = ccl_coll_reduce; @@ -542,21 +515,21 @@ ccl_status_t ccl_coll_build_reduce(ccl_sched* sched, break; default: CCL_FATAL("unexpected reduce_algo ", ccl_coll_algorithm_to_str(algo)); - return ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; } return status; } -ccl_status_t ccl_coll_build_reduce_scatter(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - ccl_comm* comm, - bool from_allreduce) { - ccl_status_t status = ccl_status_success; +ccl::status ccl_coll_build_reduce_scatter(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm, + bool from_allreduce) { + ccl::status status = ccl::status::success; ccl_selector_param param; param.ctype = ccl_coll_reduce_scatter; @@ -568,51 +541,48 @@ ccl_status_t ccl_coll_build_reduce_scatter(ccl_sched* sched, switch (algo) { case ccl_coll_reduce_scatter_direct: - if (!from_allreduce) - { + if (!from_allreduce) { CCL_CALL(ccl_coll_build_direct_reduce_scatter( sched, send_buf, recv_buf, count, dtype, reduction, comm)); break; } case ccl_coll_reduce_scatter_ring: - if (from_allreduce) - { + if (from_allreduce) { CCL_CALL(ccl_coll_build_ring_reduce_scatter( sched, send_buf, recv_buf, count, dtype, reduction, comm)); } - else - { + else { CCL_CALL(ccl_coll_build_ring_reduce_scatter_block( sched, send_buf, recv_buf, count, dtype, reduction, comm)); } break; default: CCL_FATAL("unexpected reduce_scatter_algo ", ccl_coll_algorithm_to_str(algo)); - return ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; } return status; } -ccl_status_t ccl_coll_build_sparse_allreduce(ccl_sched* sched, - ccl_buffer send_ind_buf, - size_t send_ind_count, - ccl_buffer send_val_buf, - size_t send_val_count, - void** recv_ind_buf, - size_t* 
recv_ind_count, - void** recv_val_buf, - size_t* recv_val_count, - const ccl_datatype& index_dtype, - const ccl_datatype& value_dtype, - ccl::reduction reduction, - ccl_comm* comm) { - ccl_status_t status = ccl_status_success; +ccl::status ccl_coll_build_sparse_allreduce(ccl_sched* sched, + ccl_buffer send_ind_buf, + size_t send_ind_count, + ccl_buffer send_val_buf, + size_t send_val_count, + void** recv_ind_buf, + size_t* recv_ind_count, + void** recv_val_buf, + size_t* recv_val_count, + const ccl_datatype& index_dtype, + const ccl_datatype& value_dtype, + ccl::reduction reduction, + ccl_comm* comm) { + ccl::status status = ccl::status::success; ccl_selector_param param; param.ctype = ccl_coll_sparse_allreduce; param.count = 0; - param.dtype = ccl_datatype_char; + param.dtype = ccl_datatype_int8; param.comm = comm; param.sparse_coalesce_mode = sched->coll_attr.sparse_coalesce_mode; param.sparse_allreduce_alloc_fn = sched->coll_attr.sparse_allreduce_alloc_fn; @@ -653,16 +623,16 @@ ccl_status_t ccl_coll_build_sparse_allreduce(ccl_sched* sched, send_ind_count, ", values count = ", send_val_count); - return ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; } if (ccl::global_data::env().atl_transport == ccl_atl_mpi) { /* for now all sparse_allreduce algorithms may contains direct collective entries (allreduce/allgatherv) - which should be executed in strict_start_order mode + which should be executed in strict_order mode */ - sched->strict_start_order = true; + sched->strict_order = true; } auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -692,23 +662,17 @@ ccl_status_t ccl_coll_build_sparse_allreduce(ccl_sched* sched, ccl_reduction_to_str(reduction)); switch (index_dtype.idx()) { - case ccl::datatype::int8: - CCL_SPARSE_ALLREDUCE_SELECT_V_DTYPE(char, value_dtype, algo); - break; case ccl::datatype::int32: - CCL_SPARSE_ALLREDUCE_SELECT_V_DTYPE(int, value_dtype, algo); + CCL_SPARSE_ALLREDUCE_SELECT_V_DTYPE(int32_t, value_dtype, 
algo); break; case ccl::datatype::int64: CCL_SPARSE_ALLREDUCE_SELECT_V_DTYPE(int64_t, value_dtype, algo); break; - case ccl::datatype::uint64: - CCL_SPARSE_ALLREDUCE_SELECT_V_DTYPE(uint64_t, value_dtype, algo); - break; default: CCL_FATAL("index datatype ", ccl::global_data::get().dtypes->name(index_dtype), " is not supported yet"); - return ccl_status_invalid_arguments; + return ccl::status::invalid_arguments; } return status; @@ -838,7 +802,7 @@ void ccl_barrier_impl(ccl_comm* comm, const ccl_stream* stream) { ccl_coll_param param{}; param.ctype = ccl_coll_barrier; - param.dtype = ccl_datatype_char; + param.dtype = ccl_datatype_int8; param.stream = stream; param.comm = comm; @@ -858,7 +822,7 @@ void ccl_barrier_impl(ccl_comm* comm, const ccl_stream* stream) { ccl_request* ccl_broadcast_impl(void* buf, size_t count, ccl::datatype dtype, - size_t root, + int root, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream) { @@ -882,7 +846,7 @@ ccl_request* ccl_reduce_impl(const void* send_buf, size_t count, ccl::datatype dtype, ccl::reduction reduction, - size_t root, + int root, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream) { diff --git a/src/coll/coll.hpp b/src/coll/coll.hpp index aca729971..69a8cc7ff 100644 --- a/src/coll/coll.hpp +++ b/src/coll/coll.hpp @@ -24,80 +24,82 @@ #include "coll/coll_common_attributes.hpp" +#include "internal_types.hpp" + class ccl_sched; class ccl_request; -ccl_status_t ccl_coll_build_allgatherv(ccl_sched* sched, - ccl_buffer send_buf, - size_t send_count, - ccl_buffer recv_buf, - const size_t* recv_counts, - const ccl_datatype& dtype, - ccl_comm* comm); - -ccl_status_t ccl_coll_build_allreduce(ccl_sched* sched, +ccl::status ccl_coll_build_allgatherv(ccl_sched* sched, ccl_buffer send_buf, + size_t send_count, ccl_buffer recv_buf, - size_t count, + const size_t* recv_counts, const ccl_datatype& dtype, - ccl::reduction reduction, ccl_comm* comm); -ccl_status_t ccl_coll_build_alltoall(ccl_sched* sched, 
+ccl::status ccl_coll_build_allreduce(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, size_t count, const ccl_datatype& dtype, + ccl::reduction reduction, ccl_comm* comm); -ccl_status_t ccl_coll_build_alltoallv(ccl_sched* sched, - ccl_buffer send_buf, - const size_t* send_counts, - ccl_buffer recv_buf, - const size_t* recv_counts, - const ccl_datatype& dtype, - ccl_comm* comm); +ccl::status ccl_coll_build_alltoall(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl_comm* comm); + +ccl::status ccl_coll_build_alltoallv(ccl_sched* sched, + ccl_buffer send_buf, + const size_t* send_counts, + ccl_buffer recv_buf, + const size_t* recv_counts, + const ccl_datatype& dtype, + ccl_comm* comm); + +ccl::status ccl_coll_build_barrier(ccl_sched* sched, ccl_comm* comm); -ccl_status_t ccl_coll_build_barrier(ccl_sched* sched, ccl_comm* comm); +ccl::status ccl_coll_build_bcast(ccl_sched* sched, + ccl_buffer buf, + size_t count, + const ccl_datatype& dtype, + int root, + ccl_comm* comm); -ccl_status_t ccl_coll_build_bcast(ccl_sched* sched, - ccl_buffer buf, +ccl::status ccl_coll_build_reduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, size_t count, const ccl_datatype& dtype, - size_t root, + ccl::reduction reduction, + int root, ccl_comm* comm); -ccl_status_t ccl_coll_build_reduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - size_t root, - ccl_comm* comm); - -ccl_status_t ccl_coll_build_reduce_scatter(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - ccl_comm* comm, - bool from_allreduce = false); - -ccl_status_t ccl_coll_build_sparse_allreduce(ccl_sched* sched, - ccl_buffer send_ind_buf, - size_t send_ind_count, - ccl_buffer send_val_buf, - size_t send_val_count, - void** recv_ind_buf, - size_t* 
recv_ind_count, - void** recv_val_buf, - size_t* recv_val_count, - const ccl_datatype& index_dtype, - const ccl_datatype& value_dtype, - ccl::reduction reduction, - ccl_comm* comm); +ccl::status ccl_coll_build_reduce_scatter(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm, + bool from_allreduce = false); + +ccl::status ccl_coll_build_sparse_allreduce(ccl_sched* sched, + ccl_buffer send_ind_buf, + size_t send_ind_count, + ccl_buffer send_val_buf, + size_t send_val_count, + void** recv_ind_buf, + size_t* recv_ind_count, + void** recv_val_buf, + size_t* recv_val_count, + const ccl_datatype& index_dtype, + const ccl_datatype& value_dtype, + ccl::reduction reduction, + ccl_comm* comm); ccl_request* ccl_allgatherv_impl(const void* send_buf, size_t send_count, @@ -148,7 +150,7 @@ void ccl_barrier_impl(ccl_comm* comm, const ccl_stream* stream); ccl_request* ccl_broadcast_impl(void* buf, size_t count, ccl::datatype dtype, - size_t root, + int root, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream); @@ -158,7 +160,7 @@ ccl_request* ccl_reduce_impl(const void* send_buf, size_t count, ccl::datatype dtype, ccl::reduction reduction, - size_t root, + int root, const ccl_coll_attr& attr, ccl_comm* comm, const ccl_stream* stream); diff --git a/src/coll/coll_common_attributes.cpp b/src/coll/coll_common_attributes.cpp index 943d51577..aedd96df1 100644 --- a/src/coll/coll_common_attributes.cpp +++ b/src/coll/coll_common_attributes.cpp @@ -35,36 +35,36 @@ ccl_operation_attr_impl_t::get_attribute_value(const version_traits_t& id) const return version; } -/** - * `prologue_fn` operations definitions - */ -const typename ccl_operation_attr_impl_t::prologue_fn_traits_t::return_type& -ccl_operation_attr_impl_t::get_attribute_value(const prologue_fn_traits_t& id) const { - return prologue_fn; -} +// /** +// * `prologue_fn` operations definitions +// */ +// const 
typename ccl_operation_attr_impl_t::prologue_fn_traits_t::return_type& +// ccl_operation_attr_impl_t::get_attribute_value(const prologue_fn_traits_t& id) const { +// return prologue_fn; +// } -typename ccl_operation_attr_impl_t::prologue_fn_traits_t::return_type -ccl_operation_attr_impl_t::set_attribute_value(typename prologue_fn_traits_t::type val, - const prologue_fn_traits_t& t) { - auto old = prologue_fn.get(); - prologue_fn = typename prologue_fn_traits_t::return_type{ val }; - return typename prologue_fn_traits_t::return_type{ old }; -} -/** - * `epilogue_fn` operations definitions - */ -const typename ccl_operation_attr_impl_t::epilogue_fn_traits_t::return_type& -ccl_operation_attr_impl_t::get_attribute_value(const epilogue_fn_traits_t& id) const { - return epilogue_fn; -} +// typename ccl_operation_attr_impl_t::prologue_fn_traits_t::return_type +// ccl_operation_attr_impl_t::set_attribute_value(typename prologue_fn_traits_t::type val, +// const prologue_fn_traits_t& t) { +// auto old = prologue_fn.get(); +// prologue_fn = typename prologue_fn_traits_t::return_type{ val }; +// return typename prologue_fn_traits_t::return_type{ old }; +// } +// /** +// * `epilogue_fn` operations definitions +// */ +// const typename ccl_operation_attr_impl_t::epilogue_fn_traits_t::return_type& +// ccl_operation_attr_impl_t::get_attribute_value(const epilogue_fn_traits_t& id) const { +// return epilogue_fn; +// } -typename ccl_operation_attr_impl_t::epilogue_fn_traits_t::return_type -ccl_operation_attr_impl_t::set_attribute_value(typename epilogue_fn_traits_t::type val, - const epilogue_fn_traits_t& t) { - auto old = epilogue_fn.get(); - epilogue_fn = typename epilogue_fn_traits_t::return_type{ val }; - return typename epilogue_fn_traits_t::return_type{ old }; -} +// typename ccl_operation_attr_impl_t::epilogue_fn_traits_t::return_type +// ccl_operation_attr_impl_t::set_attribute_value(typename epilogue_fn_traits_t::type val, +// const epilogue_fn_traits_t& t) { +// auto old = 
epilogue_fn.get(); +// epilogue_fn = typename epilogue_fn_traits_t::return_type{ val }; +// return typename epilogue_fn_traits_t::return_type{ old }; +// } /** * `priority` operations definitions diff --git a/src/coll/coll_common_attributes.hpp b/src/coll/coll_common_attributes.hpp index f5ce48873..36a17434a 100644 --- a/src/coll/coll_common_attributes.hpp +++ b/src/coll/coll_common_attributes.hpp @@ -14,10 +14,10 @@ limitations under the License. */ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" namespace ccl { struct ccl_operation_attr_impl_t { @@ -26,42 +26,42 @@ struct ccl_operation_attr_impl_t { * `version` operations */ using version_traits_t = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; const typename version_traits_t::return_type& get_attribute_value( const version_traits_t& id) const; typename version_traits_t::return_type set_attribute_value(typename version_traits_t::type val, const version_traits_t& t); - /** - * `prologue_fn` operations - */ - using prologue_fn_traits_t = - details::ccl_api_type_attr_traits; - const typename prologue_fn_traits_t::return_type& get_attribute_value( - const prologue_fn_traits_t& id) const; - - typename prologue_fn_traits_t::return_type set_attribute_value( - typename prologue_fn_traits_t::type val, - const prologue_fn_traits_t& t); - - /** - * `epilogue_fn` operations - */ - using epilogue_fn_traits_t = - details::ccl_api_type_attr_traits; - const typename epilogue_fn_traits_t::return_type& get_attribute_value( - const epilogue_fn_traits_t& id) const; - - typename epilogue_fn_traits_t::return_type set_attribute_value( - typename epilogue_fn_traits_t::type val, - const epilogue_fn_traits_t& t); 
+ // /** + // * `prologue_fn` operations + // */ + // using prologue_fn_traits_t = + // detail::ccl_api_type_attr_traits; + // const typename prologue_fn_traits_t::return_type& get_attribute_value( + // const prologue_fn_traits_t& id) const; + + // typename prologue_fn_traits_t::return_type set_attribute_value( + // typename prologue_fn_traits_t::type val, + // const prologue_fn_traits_t& t); + + // /** + // * `epilogue_fn` operations + // */ + // using epilogue_fn_traits_t = + // detail::ccl_api_type_attr_traits; + // const typename epilogue_fn_traits_t::return_type& get_attribute_value( + // const epilogue_fn_traits_t& id) const; + + // typename epilogue_fn_traits_t::return_type set_attribute_value( + // typename epilogue_fn_traits_t::type val, + // const epilogue_fn_traits_t& t); /** * `priority` operations */ using priority_traits_t = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; const typename priority_traits_t::return_type& get_attribute_value( const priority_traits_t& id) const; @@ -73,7 +73,7 @@ struct ccl_operation_attr_impl_t { * `synchronous` operations */ using synchronous_traits_t = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; const typename synchronous_traits_t::return_type& get_attribute_value( const synchronous_traits_t& id) const; @@ -85,7 +85,7 @@ struct ccl_operation_attr_impl_t { * `to_cache` operations */ using to_cache_traits_t = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; const typename to_cache_traits_t::return_type& get_attribute_value( const to_cache_traits_t& id) const; @@ -97,7 +97,7 @@ struct ccl_operation_attr_impl_t { * `match_id` operations */ using match_id_traits_t = - details::ccl_api_type_attr_traits; + detail::ccl_api_type_attr_traits; const typename match_id_traits_t::return_type& get_attribute_value( const match_id_traits_t& id) const; @@ -105,8 +105,8 @@ struct ccl_operation_attr_impl_t { typename match_id_traits_t::type val, const 
match_id_traits_t& t); - typename ccl_operation_attr_impl_t::prologue_fn_traits_t::return_type prologue_fn{}; - typename ccl_operation_attr_impl_t::epilogue_fn_traits_t::return_type epilogue_fn{}; + // typename ccl_operation_attr_impl_t::prologue_fn_traits_t::return_type prologue_fn{}; + // typename ccl_operation_attr_impl_t::epilogue_fn_traits_t::return_type epilogue_fn{}; /* Priority for collective operation */ size_t priority = 0; diff --git a/src/coll/coll_param.hpp b/src/coll/coll_param.hpp index e4a738dd7..6ccc95a4d 100644 --- a/src/coll/coll_param.hpp +++ b/src/coll/coll_param.hpp @@ -18,29 +18,29 @@ #include "coll/algorithms/algorithms_enum.hpp" #include "common/datatype/datatype.hpp" -#include "oneapi/ccl/ccl_type_traits.hpp" -#include "oneapi/ccl/ccl_stream_attr_ids.hpp" -#include "oneapi/ccl/ccl_stream_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_stream.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids.hpp" -#include "oneapi/ccl/ccl_coll_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_coll_attr.hpp" +#include "oneapi/ccl/type_traits.hpp" +#include "oneapi/ccl/stream_attr_ids.hpp" +#include "oneapi/ccl/stream_attr_ids_traits.hpp" +#include "oneapi/ccl/stream.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" +#include "oneapi/ccl/coll_attr.hpp" class ccl_comm; #ifdef CCL_ENABLE_SYCL #include -typedef cl::sycl::buffer ccl_sycl_buffer_t; +typedef cl::sycl::buffer ccl_sycl_buffer_t; template using ccl_sycl_typed_buffer_t = cl::sycl::buffer; /* ordering should be aligned with ccl::datatype */ -using ccl_sycle_buffer_one_dim_types = std::tuple, - ccl_sycl_typed_buffer_t, +using ccl_sycle_buffer_one_dim_types = std::tuple, + ccl_sycl_typed_buffer_t, ccl_sycl_typed_buffer_t, ccl_sycl_typed_buffer_t, - ccl_sycl_typed_buffer_t, + ccl_sycl_typed_buffer_t, ccl_sycl_typed_buffer_t, ccl_sycl_typed_buffer_t, ccl_sycl_typed_buffer_t, @@ -56,8 +56,6 @@ struct ccl_coll_attr { ccl_coll_attr() = default; ccl_coll_attr(const 
ccl_coll_attr&) = default; ccl_coll_attr& operator=(const ccl_coll_attr&) = default; - ccl_coll_attr(const ccl_coll_attr_t* attr); - ccl_coll_attr& operator=(const ccl_coll_attr_t* attr); //TODO temporary solution for type convertation, ccl_coll_attr would be depreacated ccl_coll_attr(const ccl::allgatherv_attr& attr); @@ -112,7 +110,7 @@ struct ccl_coll_param { const size_t* recv_counts; ccl_datatype dtype; ccl::reduction reduction; - size_t root; + int root; const ccl_stream* stream; ccl_comm* comm; ccl_coll_sparse_param sparse_param; diff --git a/src/coll/selection/selector_allgatherv.cpp b/src/coll/selection/selector_allgatherv.cpp index 34ea1bf24..e22b280cb 100644 --- a/src/coll/selection/selector_allgatherv.cpp +++ b/src/coll/selection/selector_allgatherv.cpp @@ -71,7 +71,7 @@ CCL_SELECTION_DEFINE_HELPER_METHODS(ccl_coll_allgatherv_algo, ({ CCL_ASSERT(param.recv_counts); size_t count = 0; - for (size_t idx = 0; idx < param.comm->size(); idx++) { + for (int idx = 0; idx < param.comm->size(); idx++) { count += param.recv_counts[idx]; } count /= param.comm->size(); diff --git a/src/coll/selection/selector_allreduce.cpp b/src/coll/selection/selector_allreduce.cpp index 7536148e4..ae8c8ce74 100644 --- a/src/coll/selection/selector_allreduce.cpp +++ b/src/coll/selection/selector_allreduce.cpp @@ -59,10 +59,9 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; - if (algo == ccl_coll_allreduce_rabenseifner && param.count < param.comm->pof2()) + if (algo == ccl_coll_allreduce_rabenseifner && (int)param.count < param.comm->pof2()) can_use = false; - else if (algo == ccl_coll_allreduce_ring_rma && - !atl_wrapper::attr.enable_rma) + else if (algo == ccl_coll_allreduce_ring_rma && !atl_wrapper::attr.enable_rma) can_use = false; else if (algo == ccl_coll_allreduce_starlike && !(param.count / param.comm->size())) can_use = false; diff --git a/src/coll/selection/selector_alltoall.cpp 
b/src/coll/selection/selector_alltoall.cpp index a544a1d70..70e92a7e9 100644 --- a/src/coll/selection/selector_alltoall.cpp +++ b/src/coll/selection/selector_alltoall.cpp @@ -52,8 +52,7 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; - if (algo == ccl_coll_alltoall_direct && - (ccl::global_data::env().atl_transport == ccl_atl_ofi)) + if (algo == ccl_coll_alltoall_direct && (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; return can_use; diff --git a/src/coll/selection/selector_alltoallv.cpp b/src/coll/selection/selector_alltoallv.cpp index 65723c7cd..e09d1fcfb 100644 --- a/src/coll/selection/selector_alltoallv.cpp +++ b/src/coll/selection/selector_alltoallv.cpp @@ -53,8 +53,7 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; - if (algo == ccl_coll_alltoallv_direct && - (ccl::global_data::env().atl_transport == ccl_atl_ofi)) + if (algo == ccl_coll_alltoallv_direct && (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; return can_use; diff --git a/src/coll/selection/selector_barrier.cpp b/src/coll/selection/selector_barrier.cpp index c7f1c23e4..cf47cd7e1 100644 --- a/src/coll/selection/selector_barrier.cpp +++ b/src/coll/selection/selector_barrier.cpp @@ -43,8 +43,7 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; - if (algo == ccl_coll_barrier_direct && - (ccl::global_data::env().atl_transport == ccl_atl_ofi)) + if (algo == ccl_coll_barrier_direct && (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; return can_use; diff --git a/src/coll/selection/selector_bcast.cpp b/src/coll/selection/selector_bcast.cpp index 578b9a0fd..aecd3e985 100644 --- a/src/coll/selection/selector_bcast.cpp +++ b/src/coll/selection/selector_bcast.cpp @@ -52,7 +52,7 @@ bool ccl_algorithm_selector_helper::can_use( can_use = false; } else if (algo == 
ccl_coll_bcast_direct && - (ccl::global_data::env().atl_transport == ccl_atl_ofi)) + (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; return can_use; diff --git a/src/coll/selection/selector_reduce.cpp b/src/coll/selection/selector_reduce.cpp index e0276a268..68cbe0ca9 100644 --- a/src/coll/selection/selector_reduce.cpp +++ b/src/coll/selection/selector_reduce.cpp @@ -45,10 +45,10 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; - if (algo == ccl_coll_reduce_rabenseifner && param.count < param.comm->pof2()) + if (algo == ccl_coll_reduce_rabenseifner && (int)param.count < param.comm->pof2()) can_use = false; else if (algo == ccl_coll_reduce_direct && - (ccl::global_data::env().atl_transport == ccl_atl_ofi)) + (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; return can_use; diff --git a/src/coll_attr_creation_impl.hpp b/src/coll_attr_creation_impl.hpp index 9a359c4c6..a20a1847e 100644 --- a/src/coll_attr_creation_impl.hpp +++ b/src/coll_attr_creation_impl.hpp @@ -14,26 +14,26 @@ limitations under the License. */ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_coll_attr.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/coll_attr.hpp" #include "coll/coll_attributes.hpp" +#include "common/utils/version.hpp" namespace ccl { -/* TODO temporary function for UT compilation: would be part of ccl::environment in final*/ + +namespace v1 { + +/* TODO temporary function for UT compilation: would be part of ccl::detail::environment in final*/ template coll_attribute_type create_coll_attr(attr_value_pair_t&&... 
avps) { - ccl::library_version ret{}; - ret.major = CCL_MAJOR_VERSION; - ret.minor = CCL_MINOR_VERSION; - ret.update = CCL_UPDATE_VERSION; - ret.product_status = CCL_PRODUCT_STATUS; - ret.build_date = CCL_PRODUCT_BUILD_DATE; - ret.full = CCL_PRODUCT_FULL; - - auto coll_attr = coll_attribute_type(ret); + auto version = utils::get_library_version(); + auto coll_attr = coll_attribute_type(version); int expander[]{ (coll_attr.template set(avps.val()), 0)... }; (void)expander; return coll_attr; } + +} // namespace v1 + } // namespace ccl diff --git a/src/coll_attr_impl.hpp b/src/coll_attr_impl.hpp index 5b5c0dfe2..7005fd386 100644 --- a/src/coll_attr_impl.hpp +++ b/src/coll_attr_impl.hpp @@ -14,42 +14,44 @@ limitations under the License. */ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_coll_attr.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/coll_attr.hpp" #include "coll/coll_attributes.hpp" namespace ccl { +namespace v1 { + template -CCL_API typename details::ccl_api_type_attr_traits::return_type allgatherv_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type allgatherv_attr::set(const Value& v) { return get_impl()->set_attribute_value( - v, details::ccl_api_type_attr_traits{}); + v, detail::ccl_api_type_attr_traits{}); } template -CCL_API typename details::ccl_api_type_attr_traits::return_type allgatherv_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type allgatherv_attr::set(const Value& v) { return static_cast(get_impl().get()) - ->set_attribute_value(v, details::ccl_api_type_attr_traits{}); + ->set_attribute_value(v, detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& allgatherv_attr::get() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + 
detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& allgatherv_attr::get() const { return static_cast(get_impl().get()) - ->get_attribute_value(details::ccl_api_type_attr_traits{}); + ->get_attribute_value(detail::ccl_api_type_attr_traits{}); } /** @@ -58,33 +60,33 @@ allgatherv_attr::get() const { template -CCL_API typename details::ccl_api_type_attr_traits::return_type allreduce_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type allreduce_attr::set(const Value& v) { return get_impl()->set_attribute_value( - v, details::ccl_api_type_attr_traits{}); + v, detail::ccl_api_type_attr_traits{}); } template -CCL_API typename details::ccl_api_type_attr_traits::return_type allreduce_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type allreduce_attr::set(const Value& v) { return static_cast(get_impl().get()) - ->set_attribute_value(v, details::ccl_api_type_attr_traits{}); + ->set_attribute_value(v, detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& allreduce_attr::get() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& allreduce_attr::get() const { return static_cast(get_impl().get()) - ->get_attribute_value(details::ccl_api_type_attr_traits{}); + ->get_attribute_value(detail::ccl_api_type_attr_traits{}); } /** @@ -93,33 +95,33 @@ allreduce_attr::get() const { template -CCL_API typename details::ccl_api_type_attr_traits::return_type alltoall_attr::set(const Value& v) +CCL_API typename 
detail::ccl_api_type_attr_traits::return_type alltoall_attr::set(const Value& v) { return get_impl()->set_attribute_value( - v, details::ccl_api_type_attr_traits{}); + v, detail::ccl_api_type_attr_traits{}); } template -CCL_API typename details::ccl_api_type_attr_traits::return_type alltoall_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type alltoall_attr::set(const Value& v) { return static_cast(get_impl().get()) - ->set_attribute_value(v, details::ccl_api_type_attr_traits{}); + ->set_attribute_value(v, detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& alltoall_attr::get() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& alltoall_attr::get() const { return static_cast(get_impl().get()) - ->get_attribute_value(details::ccl_api_type_attr_traits{}); + ->get_attribute_value(detail::ccl_api_type_attr_traits{}); } /** @@ -128,33 +130,33 @@ alltoall_attr::get() const { template -CCL_API typename details::ccl_api_type_attr_traits::return_type alltoallv_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type alltoallv_attr::set(const Value& v) { return get_impl()->set_attribute_value( - v, details::ccl_api_type_attr_traits{}); + v, detail::ccl_api_type_attr_traits{}); } template -CCL_API typename details::ccl_api_type_attr_traits::return_type alltoallv_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type alltoallv_attr::set(const Value& v) { return static_cast(get_impl().get()) - ->set_attribute_value(v, details::ccl_api_type_attr_traits{}); + ->set_attribute_value(v, 
detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& alltoallv_attr::get() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } template -const typename details::ccl_api_type_attr_traits::return_type& +const typename detail::ccl_api_type_attr_traits::return_type& alltoallv_attr::get() const { return static_cast(get_impl().get()) - ->get_attribute_value(details::ccl_api_type_attr_traits{}); + ->get_attribute_value(detail::ccl_api_type_attr_traits{}); } /** @@ -163,33 +165,33 @@ alltoallv_attr::get() const { template -CCL_API typename details::ccl_api_type_attr_traits::return_type broadcast_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type broadcast_attr::set(const Value& v) { return get_impl()->set_attribute_value( - v, details::ccl_api_type_attr_traits{}); + v, detail::ccl_api_type_attr_traits{}); } template -CCL_API typename details::ccl_api_type_attr_traits::return_type broadcast_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type broadcast_attr::set(const Value& v) { return static_cast(get_impl().get()) - ->set_attribute_value(v, details::ccl_api_type_attr_traits{}); + ->set_attribute_value(v, detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& broadcast_attr::get() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& broadcast_attr::get() const { return static_cast(get_impl().get()) - 
->get_attribute_value(details::ccl_api_type_attr_traits{}); + ->get_attribute_value(detail::ccl_api_type_attr_traits{}); } /** @@ -198,33 +200,33 @@ broadcast_attr::get() const { template -CCL_API typename details::ccl_api_type_attr_traits::return_type reduce_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type reduce_attr::set(const Value& v) { return get_impl()->set_attribute_value( - v, details::ccl_api_type_attr_traits{}); + v, detail::ccl_api_type_attr_traits{}); } template -CCL_API typename details::ccl_api_type_attr_traits::return_type reduce_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type reduce_attr::set(const Value& v) { return static_cast(get_impl().get()) - ->set_attribute_value(v, details::ccl_api_type_attr_traits{}); + ->set_attribute_value(v, detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& reduce_attr::get() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& reduce_attr::get() const { return static_cast(get_impl().get()) - ->get_attribute_value(details::ccl_api_type_attr_traits{}); + ->get_attribute_value(detail::ccl_api_type_attr_traits{}); } /** @@ -233,34 +235,34 @@ reduce_attr::get() const { template -CCL_API typename details::ccl_api_type_attr_traits::return_type reduce_scatter_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type reduce_scatter_attr::set(const Value& v) { return get_impl()->set_attribute_value( - v, details::ccl_api_type_attr_traits{}); + v, detail::ccl_api_type_attr_traits{}); } template -CCL_API typename 
details::ccl_api_type_attr_traits::return_type reduce_scatter_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type reduce_scatter_attr::set(const Value& v) { return static_cast(get_impl().get()) - ->set_attribute_value(v, details::ccl_api_type_attr_traits{}); + ->set_attribute_value(v, detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& reduce_scatter_attr::get() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& reduce_scatter_attr::get() const { return static_cast(get_impl().get()) - ->get_attribute_value(details::ccl_api_type_attr_traits{}); + ->get_attribute_value(detail::ccl_api_type_attr_traits{}); } /** @@ -269,34 +271,34 @@ reduce_scatter_attr::get() const { template -CCL_API typename details::ccl_api_type_attr_traits::return_type sparse_allreduce_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type sparse_allreduce_attr::set(const Value& v) { return get_impl()->set_attribute_value( - v, details::ccl_api_type_attr_traits{}); + v, detail::ccl_api_type_attr_traits{}); } template -CCL_API typename details::ccl_api_type_attr_traits::return_type sparse_allreduce_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type sparse_allreduce_attr::set(const Value& v) { return static_cast(get_impl().get()) - ->set_attribute_value(v, details::ccl_api_type_attr_traits{}); + ->set_attribute_value(v, detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& 
sparse_allreduce_attr::get() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& sparse_allreduce_attr::get() const { return static_cast(get_impl().get()) - ->get_attribute_value(details::ccl_api_type_attr_traits{}); + ->get_attribute_value(detail::ccl_api_type_attr_traits{}); } /** @@ -305,32 +307,35 @@ sparse_allreduce_attr::get() const { template -CCL_API typename details::ccl_api_type_attr_traits::return_type barrier_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type barrier_attr::set(const Value& v) { return get_impl()->set_attribute_value( - v, details::ccl_api_type_attr_traits{}); + v, detail::ccl_api_type_attr_traits{}); } template -CCL_API typename details::ccl_api_type_attr_traits::return_type barrier_attr::set(const Value& v) +CCL_API typename detail::ccl_api_type_attr_traits::return_type barrier_attr::set(const Value& v) { return get_impl().get()->set_attribute_value( - v, details::ccl_api_type_attr_traits{}); + v, detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& barrier_attr::get() const { return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } template -CCL_API const typename details::ccl_api_type_attr_traits::return_type& +CCL_API const typename detail::ccl_api_type_attr_traits::return_type& barrier_attr::get() const { return get_impl().get()->get_attribute_value( - details::ccl_api_type_attr_traits{}); + detail::ccl_api_type_attr_traits{}); } + +} // namespace v1 + } // namespace ccl diff --git a/src/comm_attr_impl.hpp b/src/comm_attr_impl.hpp new file mode 100644 index 
000000000..0f72e8007 --- /dev/null +++ b/src/comm_attr_impl.hpp @@ -0,0 +1,47 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/comm_attr.hpp" + +namespace ccl { + +namespace v1 { + +/** + * comm_attr attributes definition + */ +template +CCL_API Value comm_attr::set(const Value& v) { + return get_impl()->set_attribute_value( + v, detail::ccl_api_type_attr_traits{}); +} + +template +CCL_API const typename detail::ccl_api_type_attr_traits::type& +comm_attr::get() const { + return get_impl()->get_attribute_value( + detail::ccl_api_type_attr_traits{}); +} + +template +CCL_API bool comm_attr::is_valid() const noexcept { + return get_impl()->is_valid(); +} + +} // namespace v1 + +} // namespace ccl diff --git a/src/comm_split_attr_impl.hpp b/src/comm_split_attr_impl.hpp index 9ef132211..b9e3db667 100644 --- a/src/comm_split_attr_impl.hpp +++ b/src/comm_split_attr_impl.hpp @@ -13,31 +13,35 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_comm_split_attr.hpp" - -namespace ccl { - -/** - * comm_split_attr attributes definition - */ -template -CCL_API Value comm_split_attr::set(const Value& v) { - return get_impl()->set_attribute_value( - v, details::ccl_api_type_attr_traits{}); -} - -template -CCL_API const typename details::ccl_api_type_attr_traits::type& -comm_split_attr::get() const { - return get_impl()->get_attribute_value( - details::ccl_api_type_attr_traits{}); -} - -template -CCL_API bool comm_split_attr::is_valid() const noexcept { - return get_impl()->is_valid(); -} - -} // namespace ccl +#pragma once +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/comm_split_attr.hpp" + +namespace ccl { + +namespace v1 { + +/** + * comm_split_attr attributes definition + */ +template +CCL_API Value comm_split_attr::set(const Value& v) { + return get_impl()->set_attribute_value( + v, detail::ccl_api_type_attr_traits{}); +} + +template +CCL_API const typename detail::ccl_api_type_attr_traits::type& +comm_split_attr::get() const { + return get_impl()->get_attribute_value( + detail::ccl_api_type_attr_traits{}); +} + +template +CCL_API bool comm_split_attr::is_valid() const noexcept { + return get_impl()->is_valid(); +} + +} // namespace v1 + +} // namespace ccl diff --git a/src/common/comm/atl_tag.cpp b/src/common/comm/atl_tag.cpp index 794290c9a..eea162428 100644 --- a/src/common/comm/atl_tag.cpp +++ b/src/common/comm/atl_tag.cpp @@ -30,7 +30,7 @@ void ccl_atl_tag::print() { } uint64_t ccl_atl_tag::create(ccl_comm_id_t comm_id, - size_t rank, + int rank, ccl_sched_id_t sched_id, ccl_op_id_t op_id) { uint64_t tag = 0; diff --git a/src/common/comm/atl_tag.hpp b/src/common/comm/atl_tag.hpp index 5a5ae0348..ea4570121 100644 --- a/src/common/comm/atl_tag.hpp +++ b/src/common/comm/atl_tag.hpp @@ -48,7 +48,7 @@ class ccl_atl_tag { * @param op_id local operation ID. 
Used to generate unique ATL tag when the rest of input parameters do not change * @return ATL communication tag */ - uint64_t create(ccl_comm_id_t comm_id, size_t rank, ccl_sched_id_t sched_id, ccl_op_id_t op_id); + uint64_t create(ccl_comm_id_t comm_id, int rank, ccl_sched_id_t sched_id, ccl_op_id_t op_id); private: /********************************************************************************** diff --git a/src/common/comm/comm.cpp b/src/common/comm/comm.cpp index c5290f1ce..b70ffdff4 100644 --- a/src/common/comm/comm.cpp +++ b/src/common/comm/comm.cpp @@ -18,11 +18,10 @@ #include "common/comm/comm.hpp" #include "common/global/global.hpp" #include "sched/sched.hpp" -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_kvs.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/kvs.hpp" -void ccl_comm::allocate_resources() -{ +void ccl_comm::allocate_resources() { if (ccl::global_data::env().enable_unordered_coll) { unordered_coll_manager = std::unique_ptr(new ccl_unordered_coll_manager(*this)); @@ -30,27 +29,26 @@ void ccl_comm::allocate_resources() auto& env_object = ccl::global_data::env(); - allreduce_2d_builder = std::unique_ptr( - new ccl_allreduce_2d_builder( - (env_object.allreduce_2d_base_size != CCL_ENV_SIZET_NOT_SPECIFIED) - ? env_object.allreduce_2d_base_size - : ccl::global_data::get().executor->get_local_proc_count(), - env_object.allreduce_2d_switch_dims, - this)); + allreduce_2d_builder = std::unique_ptr(new ccl_allreduce_2d_builder( + (env_object.allreduce_2d_base_size != CCL_ENV_SIZET_NOT_SPECIFIED) + ? 
env_object.allreduce_2d_base_size + : ccl::global_data::get().executor->get_local_proc_count(), + env_object.allreduce_2d_switch_dims, + this)); if (m_rank == 0) env_object.print(); } -ccl_comm::ccl_comm(size_t rank, - size_t size, +ccl_comm::ccl_comm(int rank, + int size, ccl_comm_id_storage::comm_id&& id, std::shared_ptr atl, bool share_resources) : ccl_comm(rank, size, std::move(id), ccl_rank2rank_map{}, atl, share_resources) {} -ccl_comm::ccl_comm(size_t rank, - size_t size, +ccl_comm::ccl_comm(int rank, + int size, ccl_comm_id_storage::comm_id&& id, ccl_rank2rank_map&& rank_map, std::shared_ptr atl, @@ -63,8 +61,7 @@ ccl_comm::ccl_comm(size_t rank, on_process_ranks_number(1) { reset(rank, size); - if (!share_resources) - { + if (!share_resources) { allocate_resources(); } } @@ -79,27 +76,24 @@ void ccl_comm::ccl_comm_reset_thread_barrier() { thread_ranks_counter.store(0); } -ccl_comm::ccl_comm(const std::vector& local_thread_device_ranks, - size_t cluster_devices_count, +ccl_comm::ccl_comm(const std::vector& local_ranks, + int comm_size, std::shared_ptr kvs_instance, ccl_comm_id_storage::comm_id&& id, bool share_resources) : m_id(std::move(id)), m_local2global_map(), - m_dtree(local_thread_device_ranks.size(), cluster_devices_count) { - + m_dtree(local_ranks.size(), comm_size) { std::shared_ptr kvs_wrapper(new users_kvs(kvs_instance)); - atl = std::shared_ptr( - new atl_wrapper(cluster_devices_count, local_thread_device_ranks, kvs_wrapper)); + atl = std::shared_ptr(new atl_wrapper(comm_size, local_ranks, kvs_wrapper)); - thread_number = atl->get_threads_count(); - on_process_ranks_number = atl->get_devices_per_rank_count(); + thread_number = atl->get_threads_per_process(); + on_process_ranks_number = atl->get_ranks_per_process(); reset(atl->get_rank(), atl->get_size()); - if (!share_resources) - { + if (!share_resources) { allocate_resources(); } } @@ -109,11 +103,11 @@ ccl_comm* ccl_comm::create_with_colors(const std::vector& colors, const ccl_comm* 
parent_comm, bool share_resources) { ccl_rank2rank_map rank_map; - size_t new_comm_size = 0; - size_t new_comm_rank = 0; + int new_comm_size = 0; + int new_comm_rank = 0; int color = colors[parent_comm->rank()]; - for (size_t i = 0; i < parent_comm->size(); ++i) { + for (int i = 0; i < parent_comm->size(); ++i) { if (colors[i] == color) { LOG_DEBUG("map local rank ", new_comm_size, " to global ", i); rank_map.emplace_back(i); @@ -134,8 +128,12 @@ ccl_comm* ccl_comm::create_with_colors(const std::vector& colors, rank_map.clear(); } - ccl_comm* comm = new ccl_comm( - new_comm_rank, new_comm_size, comm_ids->acquire(), std::move(rank_map), parent_comm->atl, share_resources); + ccl_comm* comm = new ccl_comm(new_comm_rank, + new_comm_size, + comm_ids->acquire(), + std::move(rank_map), + parent_comm->atl, + share_resources); LOG_DEBUG("new comm: color ", color, @@ -151,23 +149,24 @@ ccl_comm* ccl_comm::create_with_colors(const std::vector& colors, std::shared_ptr ccl_comm::clone_with_new_id(ccl_comm_id_storage::comm_id&& id) { ccl_rank2rank_map rank_map{ m_local2global_map }; - return std::make_shared(m_rank, m_size, std::move(id), std::move(rank_map), atl, true /*share_resources*/); + return std::make_shared( + m_rank, m_size, std::move(id), std::move(rank_map), atl, true /*share_resources*/); } -size_t ccl_comm::get_global_rank(size_t rank) const { +int ccl_comm::get_global_rank(int rank) const { if (m_local2global_map.empty()) { // global comm and its copies do not have entries in the map return rank; } - CCL_THROW_IF_NOT(m_local2global_map.size() > rank, + CCL_THROW_IF_NOT((int)m_local2global_map.size() > rank, "no rank ", rank, " was found in comm ", this, ", id ", m_id.value()); - size_t global_rank = m_local2global_map[rank]; + int global_rank = m_local2global_map[rank]; LOG_DEBUG( "comm , ", this, " id ", m_id.value(), ", map rank ", rank, " to global ", global_rank); return global_rank; diff --git a/src/common/comm/comm.hpp b/src/common/comm/comm.hpp index 
03a081a2c..5a9e2072a 100644 --- a/src/common/comm/comm.hpp +++ b/src/common/comm/comm.hpp @@ -13,168 +13,169 @@ See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - -#include -#include - -#include "atl/atl_wrapper.h" -#include "coll/algorithms/allreduce/allreduce_2d.hpp" -#include "common/comm/comm_id_storage.hpp" -#include "common/comm/atl_tag.hpp" -#include "common/log/log.hpp" -#include "common/utils/tree.hpp" -#include "common/utils/utils.hpp" -#include "unordered_coll/unordered_coll.hpp" - -// index = local_rank, value = global_rank -using ccl_rank2rank_map = std::vector; - -namespace ccl { -class kvs_interface; -} - -class alignas(CACHELINE_SIZE) ccl_comm { -public: - //TODO - static void ccl_comm_reset_thread_barrier(); - ccl_comm() = delete; - ccl_comm(const ccl_comm& other) = delete; - ccl_comm& operator=(const ccl_comm& other) = delete; - - ccl_comm(size_t rank, - size_t size, - ccl_comm_id_storage::comm_id&& id, - std::shared_ptr atl, - bool share_resources = false); - ccl_comm(size_t rank, - size_t size, - ccl_comm_id_storage::comm_id&& id, - ccl_rank2rank_map&& ranks, - std::shared_ptr atl, - bool share_resources = false); - - //TODO non-implemented - //1) cluster_devices_count (devices 1000) -> (processes 10) - //2) blocking until all thread -> calls ccl_comm - //3) return 'thread_count' - - // ccl_comm( {0,1,2,3...}, 1000, kvs ) - // from 20 processes from ranks 0,1,2,3. 
Each rank contains 10 threads - // communicator: size in {20} and ranks in {0..19} - // communicator: return threads count in process {10} - // communicator: return devices counts per thread in process - ccl_comm(const std::vector& local_thread_device_ranks, - size_t cluster_devices_count, - std::shared_ptr kvs_instance, - ccl_comm_id_storage::comm_id&& id, - bool share_resources = false); - - ~ccl_comm() = default; - - /* version with user-provided colors, allows to skip allgatherv */ - static ccl_comm* create_with_colors(const std::vector& colors, - ccl_comm_id_storage* comm_ids, - const ccl_comm* parent_comm, - bool share_resources = false); - - std::shared_ptr clone_with_new_id(ccl_comm_id_storage::comm_id&& id); - - size_t rank() const noexcept { - return m_rank; - } - - size_t size() const noexcept { - return m_size; - } - - size_t pof2() const noexcept { - return m_pof2; - } - - ccl_comm_id_t id() const noexcept { - return m_id.value(); - } - - size_t thread_count() const noexcept { - return thread_number; - } - - size_t on_process_ranks_count() const noexcept { - return on_process_ranks_number; - } - - ccl_sched_id_t get_sched_id(bool use_internal_space) { - ccl_sched_id_t& next_sched_id = - (use_internal_space) ? m_next_sched_id_internal : m_next_sched_id_external; - - ccl_sched_id_t first_sched_id = - (use_internal_space) ? static_cast(0) : ccl_comm::max_sched_count / 2; - - ccl_sched_id_t max_sched_id = - (use_internal_space) ? 
ccl_comm::max_sched_count / 2 : ccl_comm::max_sched_count; - - ccl_sched_id_t id = next_sched_id; - - ++next_sched_id; - - if (next_sched_id == max_sched_id) { - /* wrap the sched numbers around to the start */ - next_sched_id = first_sched_id; - } - - LOG_DEBUG("sched_id ", id, ", comm_id ", m_id.value(), ", next sched_id ", next_sched_id); - - return id; - } - - void reset(size_t rank, size_t size) { - m_rank = rank; - m_size = size; - m_pof2 = ccl_pof2(m_size); - - m_next_sched_id_internal = ccl_comm::max_sched_count / 2; - m_next_sched_id_external = 0; - } - - /** - * Returns the number of @c rank in the global communicator - * @param rank a rank which is part of the current communicator - * @return number of @c rank in the global communicator - */ - size_t get_global_rank(size_t rank) const; - - const ccl_double_tree& dtree() const { - return m_dtree; - } - - /** - * Maximum available number of active communicators - */ - static constexpr ccl_sched_id_t max_comm_count = std::numeric_limits::max(); - /** - * Maximum value of schedule id in scope of the current communicator - */ - static constexpr ccl_sched_id_t max_sched_count = std::numeric_limits::max(); - - std::shared_ptr atl; - std::unique_ptr unordered_coll_manager; - std::unique_ptr allreduce_2d_builder; - -private: - - void allocate_resources(); - - size_t m_rank; - size_t m_size; - size_t m_pof2; - - ccl_comm_id_storage::comm_id m_id; - ccl_sched_id_t m_next_sched_id_internal; - ccl_sched_id_t m_next_sched_id_external; - ccl_rank2rank_map m_local2global_map{}; - ccl_double_tree m_dtree; - - size_t thread_number; - size_t on_process_ranks_number; -}; +#pragma once + +#include +#include + +#include "atl/atl_wrapper.h" +#include "coll/algorithms/allreduce/allreduce_2d.hpp" +#include "common/comm/comm_id_storage.hpp" +#include "common/comm/atl_tag.hpp" +#include "common/log/log.hpp" +#include "common/utils/tree.hpp" +#include "common/utils/utils.hpp" +#include "unordered_coll/unordered_coll.hpp" + +// 
index = local_rank, value = global_rank +using ccl_rank2rank_map = std::vector; + +namespace ccl { +namespace v1 { +class kvs_interface; +} +} // namespace ccl + +class alignas(CACHELINE_SIZE) ccl_comm { +public: + //TODO + static void ccl_comm_reset_thread_barrier(); + ccl_comm() = delete; + ccl_comm(const ccl_comm& other) = delete; + ccl_comm& operator=(const ccl_comm& other) = delete; + + ccl_comm(int rank, + int size, + ccl_comm_id_storage::comm_id&& id, + std::shared_ptr atl, + bool share_resources = false); + ccl_comm(int rank, + int size, + ccl_comm_id_storage::comm_id&& id, + ccl_rank2rank_map&& ranks, + std::shared_ptr atl, + bool share_resources = false); + + //TODO non-implemented + //1) cluster_devices_count (devices 1000) -> (processes 10) + //2) blocking until all thread -> calls ccl_comm + //3) return 'thread_count' + + // ccl_comm( {0,1,2,3...}, 1000, kvs ) + // from 20 processes from ranks 0,1,2,3. Each rank contains 10 threads + // communicator: size in {20} and ranks in {0..19} + // communicator: return threads count in process {10} + // communicator: return devices counts per thread in process + ccl_comm(const std::vector& local_ranks, + int comm_size, + std::shared_ptr kvs_instance, + ccl_comm_id_storage::comm_id&& id, + bool share_resources = false); + + ~ccl_comm() = default; + + /* version with user-provided colors, allows to skip allgatherv */ + static ccl_comm* create_with_colors(const std::vector& colors, + ccl_comm_id_storage* comm_ids, + const ccl_comm* parent_comm, + bool share_resources = false); + + std::shared_ptr clone_with_new_id(ccl_comm_id_storage::comm_id&& id); + + int rank() const noexcept { + return m_rank; + } + + int size() const noexcept { + return m_size; + } + + int pof2() const noexcept { + return m_pof2; + } + + ccl_comm_id_t id() const noexcept { + return m_id.value(); + } + + size_t thread_count() const noexcept { + return thread_number; + } + + size_t ranks_per_process() const noexcept { + return 
on_process_ranks_number; + } + + ccl_sched_id_t get_sched_id(bool use_internal_space) { + ccl_sched_id_t& next_sched_id = + (use_internal_space) ? m_next_sched_id_internal : m_next_sched_id_external; + + ccl_sched_id_t first_sched_id = + (use_internal_space) ? static_cast(0) : ccl_comm::max_sched_count / 2; + + ccl_sched_id_t max_sched_id = + (use_internal_space) ? ccl_comm::max_sched_count / 2 : ccl_comm::max_sched_count; + + ccl_sched_id_t id = next_sched_id; + + ++next_sched_id; + + if (next_sched_id == max_sched_id) { + /* wrap the sched numbers around to the start */ + next_sched_id = first_sched_id; + } + + LOG_DEBUG("sched_id ", id, ", comm_id ", m_id.value(), ", next sched_id ", next_sched_id); + + return id; + } + + void reset(int rank, int size) { + m_rank = rank; + m_size = size; + m_pof2 = ccl_pof2(m_size); + + m_next_sched_id_internal = ccl_comm::max_sched_count / 2; + m_next_sched_id_external = 0; + } + + /** + * Returns the number of @c rank in the global communicator + * @param rank a rank which is part of the current communicator + * @return number of @c rank in the global communicator + */ + int get_global_rank(int rank) const; + + const ccl_double_tree& dtree() const { + return m_dtree; + } + + /** + * Maximum available number of active communicators + */ + static constexpr ccl_sched_id_t max_comm_count = std::numeric_limits::max(); + /** + * Maximum value of schedule id in scope of the current communicator + */ + static constexpr ccl_sched_id_t max_sched_count = std::numeric_limits::max(); + + std::shared_ptr atl; + std::unique_ptr unordered_coll_manager; + std::unique_ptr allreduce_2d_builder; + +private: + void allocate_resources(); + + int m_rank; + int m_size; + int m_pof2; + + ccl_comm_id_storage::comm_id m_id; + ccl_sched_id_t m_next_sched_id_internal; + ccl_sched_id_t m_next_sched_id_external; + ccl_rank2rank_map m_local2global_map{}; + ccl_double_tree m_dtree; + + size_t thread_number; + size_t on_process_ranks_number; +}; diff --git 
a/src/common/comm/comm_common_attr.hpp b/src/common/comm/comm_common_attr.hpp new file mode 100644 index 000000000..af0f7d0c9 --- /dev/null +++ b/src/common/comm/comm_common_attr.hpp @@ -0,0 +1,52 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/comm_attr_ids_traits.hpp" + +namespace ccl { + +class ccl_comm_attr_impl { +public: + /** + * `version` operations + */ + using version_traits_t = detail::ccl_api_type_attr_traits; + + const typename version_traits_t::return_type& get_attribute_value( + const version_traits_t& id) const { + return version; + } + + typename version_traits_t::return_type set_attribute_value(typename version_traits_t::type val, + const version_traits_t& t) { + (void)t; + throw ccl::exception("Set value for 'ccl::comm_attr_id::version' is not allowed"); + return version; + } + + ccl_comm_attr_impl(const typename version_traits_t::return_type& version) : version(version) {} + + template + bool is_valid() const noexcept { + return (attr_id == comm_attr_id::version); + } + +protected: + typename version_traits_t::return_type version; +}; + +} // namespace ccl diff --git a/src/common/comm/comm_id_storage.hpp b/src/common/comm/comm_id_storage.hpp index 7025b8606..627bc7c9b 100644 --- a/src/common/comm/comm_id_storage.hpp +++ b/src/common/comm/comm_id_storage.hpp @@ -15,7 +15,7 @@ */ #pragma once -#include "oneapi/ccl/ccl_types.hpp" +#include 
"oneapi/ccl/types.hpp" #include "common/log/log.hpp" #include "common/utils/spinlock.hpp" @@ -127,7 +127,6 @@ class ccl_comm_id_storage { ccl_comm_id_t acquire_id_impl(ccl_comm_id_t last_used, ccl_comm_id_t lower_bound, ccl_comm_id_t upper_bound) { - //search from the current position till the end LOG_DEBUG("last ", last_used, ", low ", lower_bound, " up ", upper_bound); diff --git a/src/common/comm/comm_interface.cpp b/src/common/comm/comm_interface.cpp index 7c270c28b..e5a153671 100644 --- a/src/common/comm/comm_interface.cpp +++ b/src/common/comm/comm_interface.cpp @@ -16,6 +16,9 @@ #include "common/comm/comm_interface.hpp" #include "common/comm/compiler_comm_interface_dispatcher_impl.hpp" - -COMMUNICATOR_INTERFACE_DISPATCHER_CLASS_EXPLICIT_INSTANTIATION(typename ccl::unified_device_type::ccl_native_t, typename ccl::unified_device_context_type::ccl_native_t); -COMMUNICATOR_INTERFACE_DISPATCHER_NON_CLASS_EXPLICIT_INSTANTIATION(ccl::device_index_type, typename ccl::unified_device_context_type::ccl_native_t); +COMMUNICATOR_INTERFACE_DISPATCHER_CLASS_EXPLICIT_INSTANTIATION( + typename ccl::unified_device_type::ccl_native_t, + typename ccl::unified_context_type::ccl_native_t); +COMMUNICATOR_INTERFACE_DISPATCHER_NON_CLASS_EXPLICIT_INSTANTIATION( + ccl::device_index_type, + typename ccl::unified_context_type::ccl_native_t); diff --git a/src/common/comm/comm_interface.hpp b/src/common/comm/comm_interface.hpp index ab6e119df..1e813d8c5 100644 --- a/src/common/comm/comm_interface.hpp +++ b/src/common/comm/comm_interface.hpp @@ -14,32 +14,30 @@ limitations under the License. 
*/ #pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_type_traits.hpp" -#include "oneapi/ccl/ccl_types_policy.hpp" -#include "oneapi/ccl/ccl_event.hpp" -#include "oneapi/ccl/ccl_comm_split_attr_ids.hpp" -#include "oneapi/ccl/ccl_comm_split_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_comm_split_attr.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/type_traits.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/event.hpp" -#include "oneapi/ccl/ccl_stream_attr_ids.hpp" -#include "oneapi/ccl/ccl_stream_attr_ids_traits.hpp" -#include "oneapi/ccl/ccl_stream.hpp" +#include "oneapi/ccl/comm_split_attr_ids.hpp" +#include "oneapi/ccl/comm_split_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_split_attr.hpp" -#include "common/event/event_internal/event_internal_attr_ids.hpp" -#include "common/event/event_internal/event_internal_attr_ids_traits.hpp" -#include "common/event/event_internal/event_internal.hpp" +#include "oneapi/ccl/stream_attr_ids.hpp" +#include "oneapi/ccl/stream_attr_ids_traits.hpp" +#include "oneapi/ccl/stream.hpp" #include "common/comm/compiler_comm_interface_dispatcher.hpp" #include "common/comm/l0/comm_context_id.hpp" +#include "internal_types.hpp" namespace native { struct ccl_device; } namespace ccl { -struct gpu_comm_attr; +namespace v1 { class allgatherv_attr; class allreduce_attr; class alltoall_attr; @@ -49,16 +47,94 @@ class broadcast_attr; class reduce_attr; class reduce_scatter_attr; class sparse_allreduce_attr; +} // namespace v1 + +struct gpu_comm_attr; } // namespace ccl #include "types_generator_defines.hpp" +#define COMM_INTERFACE_COLL_METHODS(TYPE) \ +\ + COMM_INTERFACE_COLL_##TYPE##__VOID; \ + COMM_INTERFACE_COLL_##TYPE(int8_t); \ + COMM_INTERFACE_COLL_##TYPE(uint8_t); \ + COMM_INTERFACE_COLL_##TYPE(int16_t); \ + COMM_INTERFACE_COLL_##TYPE(uint16_t); \ + COMM_INTERFACE_COLL_##TYPE(int32_t); \ + COMM_INTERFACE_COLL_##TYPE(uint32_t); \ + COMM_INTERFACE_COLL_##TYPE(int64_t); \ + 
COMM_INTERFACE_COLL_##TYPE(uint64_t); \ + COMM_INTERFACE_COLL_##TYPE(float); \ + COMM_INTERFACE_COLL_##TYPE(double); \ +\ + COMM_INTERFACE_SPARSE_##TYPE##__VOID; \ + COMM_INTERFACE_SPARSE_##TYPE(int32_t, ccl::bfloat16); \ + COMM_INTERFACE_SPARSE_##TYPE(int32_t, float); \ + COMM_INTERFACE_SPARSE_##TYPE(int64_t, ccl::bfloat16); \ + COMM_INTERFACE_SPARSE_##TYPE(int64_t, float); + +#define SYCL_COMM_INTERFACE_COLL_METHODS(TYPE) \ + COMM_INTERFACE_COLL_CLASS_##TYPE(cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_##TYPE(cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_##TYPE(cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_##TYPE(cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_##TYPE(cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_##TYPE(cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_##TYPE(cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_##TYPE(cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_##TYPE(cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_##TYPE(cl::sycl::buffer); \ +\ + COMM_INTERFACE_SPARSE_CLASS_##TYPE(cl::sycl::buffer, \ + cl::sycl::buffer); \ + COMM_INTERFACE_SPARSE_CLASS_##TYPE(cl::sycl::buffer, \ + cl::sycl::buffer); \ +\ + COMM_INTERFACE_SPARSE_CLASS_##TYPE(cl::sycl::buffer, \ + cl::sycl::buffer); \ + COMM_INTERFACE_SPARSE_CLASS_##TYPE(cl::sycl::buffer, \ + cl::sycl::buffer); + +#define COMM_INTERFACE_COLL_INSTANTIATION(COMM) \ + COMM_INTERFACE_COLL_INSTANTIATIONS(COMM, int8_t); \ + COMM_INTERFACE_COLL_INSTANTIATIONS(COMM, uint8_t); \ + COMM_INTERFACE_COLL_INSTANTIATIONS(COMM, int16_t); \ + COMM_INTERFACE_COLL_INSTANTIATIONS(COMM, uint16_t); \ + COMM_INTERFACE_COLL_INSTANTIATIONS(COMM, int32_t); \ + COMM_INTERFACE_COLL_INSTANTIATIONS(COMM, uint32_t); \ + COMM_INTERFACE_COLL_INSTANTIATIONS(COMM, int64_t); \ + COMM_INTERFACE_COLL_INSTANTIATIONS(COMM, uint64_t); \ + COMM_INTERFACE_COLL_INSTANTIATIONS(COMM, float); \ + COMM_INTERFACE_COLL_INSTANTIATIONS(COMM, double); \ + COMM_INTERFACE_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(COMM, int32_t, float); \ + 
COMM_INTERFACE_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(COMM, int32_t, ccl::bfloat16); \ + COMM_INTERFACE_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(COMM, int64_t, float); \ + COMM_INTERFACE_SPARSE_ALLREDUCE_EXPLICIT_INSTANTIATION(COMM, int64_t, ccl::bfloat16); + +#define SYCL_COMM_INTERFACE_COLL_INSTANTIATION(COMM) \ + COMM_INTERFACE_COLL_CLASS_INSTANTIATIONS(COMM, cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_INSTANTIATIONS(COMM, cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_INSTANTIATIONS(COMM, cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_INSTANTIATIONS(COMM, cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_INSTANTIATIONS(COMM, cl::sycl::buffer); \ + COMM_INTERFACE_COLL_CLASS_INSTANTIATIONS(COMM, cl::sycl::buffer); \ +\ + COMM_INTERFACE_SPARSE_ALLREDUCE_EXPLICIT_CLASS_INSTANTIATION( \ + COMM, cl::sycl::buffer, cl::sycl::buffer); \ + COMM_INTERFACE_SPARSE_ALLREDUCE_EXPLICIT_CLASS_INSTANTIATION( \ + COMM, cl::sycl::buffer, cl::sycl::buffer); \ + COMM_INTERFACE_SPARSE_ALLREDUCE_EXPLICIT_CLASS_INSTANTIATION( \ + COMM, cl::sycl::buffer, cl::sycl::buffer); \ + COMM_INTERFACE_SPARSE_ALLREDUCE_EXPLICIT_CLASS_INSTANTIATION( \ + COMM, cl::sycl::buffer, cl::sycl::buffer); + namespace ccl { struct communicator_interface : public communicator_interface_dispatcher { virtual ~communicator_interface() = default; - virtual size_t rank() const = 0; - virtual size_t size() const = 0; + virtual int rank() const = 0; + virtual int size() const = 0; virtual bool is_host() const noexcept = 0; virtual bool is_cpu() const noexcept = 0; @@ -73,66 +149,12 @@ struct communicator_interface : public communicator_interface_dispatcher { // collectives operation declarations virtual ccl::event barrier(const stream::impl_value_t& op_stream, - const barrier_attr& attr, - const vector_class& deps = {}) = 0; - - DEVICE_COMM_INTERFACE_COLL_DECLARATION__VOID; - DEVICE_COMM_INTERFACE_COLL_DECLARATION(char); - DEVICE_COMM_INTERFACE_COLL_DECLARATION(int); - 
DEVICE_COMM_INTERFACE_COLL_DECLARATION(int64_t); - DEVICE_COMM_INTERFACE_COLL_DECLARATION(uint64_t); - DEVICE_COMM_INTERFACE_COLL_DECLARATION(float); - DEVICE_COMM_INTERFACE_COLL_DECLARATION(double); - -#ifdef CCL_ENABLE_SYCL - DEVICE_COMM_INTERFACE_COLL_CLASS_DECLARATION(cl::sycl::buffer); - DEVICE_COMM_INTERFACE_COLL_CLASS_DECLARATION(cl::sycl::buffer); - DEVICE_COMM_INTERFACE_COLL_CLASS_DECLARATION(cl::sycl::buffer); - DEVICE_COMM_INTERFACE_COLL_CLASS_DECLARATION(cl::sycl::buffer); - DEVICE_COMM_INTERFACE_COLL_CLASS_DECLARATION(cl::sycl::buffer); - DEVICE_COMM_INTERFACE_COLL_CLASS_DECLARATION(cl::sycl::buffer); -#endif //CCL_ENABLE_SYCL - - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION__VOID - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(char, char); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(char, int); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(char, ccl::bf16); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(char, float); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(char, double); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(char, int64_t); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(char, uint64_t); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int, char); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int, int); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int, ccl::bf16); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int, float); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int, double); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int, int64_t); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int, uint64_t); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int64_t, char); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int64_t, int); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int64_t, ccl::bf16); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int64_t, float); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int64_t, double); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int64_t, int64_t); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(int64_t, uint64_t); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(uint64_t, char); - 
DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(uint64_t, int); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(uint64_t, ccl::bf16); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(uint64_t, float); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(uint64_t, double); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(uint64_t, int64_t); - DEVICE_COMM_INTERFACE_SPARSE_DECLARATION(uint64_t, uint64_t); + const barrier_attr& attr, + const vector_class& deps = {}) = 0; + COMM_INTERFACE_COLL_METHODS(DECLARATION); #ifdef CCL_ENABLE_SYCL - DEVICE_COMM_INTERFACE_SPARSE_CLASS_DECLARATION(cl::sycl::buffer, - cl::sycl::buffer); - DEVICE_COMM_INTERFACE_SPARSE_CLASS_DECLARATION(cl::sycl::buffer, - cl::sycl::buffer); - - DEVICE_COMM_INTERFACE_SPARSE_CLASS_DECLARATION(cl::sycl::buffer, - cl::sycl::buffer); - DEVICE_COMM_INTERFACE_SPARSE_CLASS_DECLARATION(cl::sycl::buffer, - cl::sycl::buffer); -#endif //CCL_ENABLE_SYCL + SYCL_COMM_INTERFACE_COLL_METHODS(DECLARATION); +#endif /* CCL_ENABLE_SYCL */ }; } // namespace ccl diff --git a/src/common/comm/comm_split_common_attr.hpp b/src/common/comm/comm_split_common_attr.hpp index d6cf1fa65..f59e9d90f 100644 --- a/src/common/comm/comm_split_common_attr.hpp +++ b/src/common/comm/comm_split_common_attr.hpp @@ -13,148 +13,150 @@ See the License for the specific language governing permissions and limitations under the License. */ -#pragma once -#include "oneapi/ccl/ccl_types.hpp" -#include "oneapi/ccl/ccl_comm_split_attr_ids_traits.hpp" - -namespace ccl { - -/** - * Base implementation - */ -template