diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b8d9b4d..6443aed2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,8 @@ option(WITH_TORCH_DEBUG "Compute RMSE of Surrogate Model and Physics Module" option(WITH_TESTS "Compile tests" OFF) option(WITH_REDIS "Use REDIS as a database back end" OFF) option(WITH_HDF5 "Use HDF5 as a database back end" OFF) +option(HDF5_USE_STATIC_LIBRARIES "Use static HDF5." OFF) +set(HDF5_WITH_ZLIB "" CACHE FILETYPE "Use the following zlib for HDF5") option(WITH_RMQ "Use RabbitMQ as a database back end (require a reachable and running RabbitMQ server service)" OFF) option(WITH_AMS_DEBUG "Enable verbose messages" OFF) option(WITH_PERFFLOWASPECT "Use PerfFlowAspect for Profiling" OFF) @@ -46,6 +48,7 @@ option(WITH_WORKFLOW "Install python drivers used by the outer workflow" O option(WITH_AMS_LIB "Install C++ library to support scientific applications" ON) option(WITH_ADIAK "Use Adiak for recording metadata" OFF) option(BUILD_SHARED_LIBS "Build using shared libraries" ON) +option(EXCLUDE_STATIC_LIBS "Exclude static libs from the linking line" OFF) if (WITH_MPI) # SET(CMAKE_CXX_COMPILER "${MPI_CXX_COMPILER}" CACHE FILEPATH "CXX compiler overridden with MPI C++ wrapper") @@ -69,6 +72,7 @@ if (WITH_CUDA) if (BUILD_SHARED_LIBS) set(CUDA_RUNTIME_LIBRARY "Shared") else() + set(HDF5_USE_STATIC_LIBRARIES ON) set(CUDA_RUNTIME_LIBRARY "Static") endif() @@ -91,7 +95,7 @@ if (WITH_CALIPER) endif() if (WITH_AMS_DEBUG) - list(APPEND AMS_APP_DEFINES "-DLIBAMS_VERBOSE") + list(APPEND AMS_APP_DEFINES "-DAMS_DEBUG") endif() # ------------------------------------------------------------------------------ @@ -132,14 +136,17 @@ endif() # WITH_REDIS if (WITH_HDF5) if (HDF5_USE_STATIC_LIBRARIES) - find_package(HDF5 NAMES hdf5 COMPONENTS C static NO_DEFAULT_PATH PATHS ${AMS_HDF5_DIR} ${AMS_HDF5_DIR}/share/cmake) - list(APPEND AMS_APP_LIBRARIES ${HDF5_C_STATIC_LIBRARY}) - message(STATUS "HDF5 Static Library : ${HDF5_C_STATIC_LIBRARY}") -else() - find_package(HDF5 NAMES hdf5 COMPONENTS C shared NO_DEFAULT_PATH PATHS ${AMS_HDF5_DIR} ${AMS_HDF5_DIR}/share/cmake) - list(APPEND AMS_APP_LIBRARIES ${HDF5_C_SHARED_LIBRARY}) - message(STATUS "HDF5 Shared Library : ${HDF5_C_SHARED_LIBRARY}") -endif() + find_package(HDF5 NAMES hdf5 COMPONENTS C static NO_DEFAULT_PATH PATHS ${AMS_HDF5_DIR} ${AMS_HDF5_DIR}/share/cmake) + list(APPEND AMS_APP_LIBRARIES ${HDF5_C_STATIC_LIBRARY}) + message(STATUS "HDF5 Static Library : ${HDF5_C_STATIC_LIBRARY}") + else() + find_package(HDF5 NAMES hdf5 COMPONENTS C shared NO_DEFAULT_PATH PATHS ${AMS_HDF5_DIR} ${AMS_HDF5_DIR}/share/cmake) + list(APPEND AMS_APP_LIBRARIES ${HDF5_C_SHARED_LIBRARY}) + message(STATUS "HDF5 Shared Library : ${HDF5_C_SHARED_LIBRARY}") + endif() + if (NOT HDF5_WITH_ZLIB STREQUAL "") + list(APPEND AMS_APP_LIBRARIES ${HDF5_WITH_ZLIB}) + endif() list(APPEND AMS_APP_INCLUDES ${HDF5_INCLUDE_DIR}) list(APPEND AMS_APP_DEFINES "-D__ENABLE_HDF5__") message(STATUS "HDF5 Include directories: ${HDF5_INCLUDE_DIR}") @@ -191,16 +198,50 @@ if (WITH_TORCH) find_package(Torch REQUIRED) # This is annoying, torch populates all my cuda flags # and resets them - set(CMAKE_CUDA_FLAGS "") + # set(CMAKE_CUDA_FLAGS "") set(CMAKE_CUDA_ARCHITECTURES ON) - list(APPEND AMS_APP_INCLUDES "${TORCH_INCLUDE_DIRS}") - list(APPEND AMS_APP_LIBRARIES "${TORCH_LIBRARIES}") - + #get_target_property(torch_interface_system_includes + # torch INTERFACE_SYSTEM_INCLUDE_DIRECTORIES) + #if ( torch_interface_system_includes ) + # list(APPEND AMS_APP_INCLUDES ${torch_interface_system_includes}) + #endif() + # + #get_target_property(torch_interface_includes + # torch INTERFACE_INCLUDE_DIRECTORIES) + #if ( torch_interface_includes ) + # list(APPEND AMS_APP_INCLUDES ${torch_interface_includes}) + #endif() + # + #get_target_property(torch_interface_defines + # torch INTERFACE_COMPILE_DEFINITIONS) + #if ( troch_interface_defines ) + # list(APPEND AMS_APP_DEFINES ${torch_interface_defines}) + #endif() + # + #get_target_property(torch_interface_compile_options + # torch INTERFACE_COMPILE_OPTIONS) + #if ( torch_interface_compile_options ) + # list(APPEND AMS_APP_DEFINES ${torch_interface_compile_options}) + #endif() + # + #get_target_property(_interface_link_directories + # ${arg_FROM} INTERFACE_LINK_DIRECTORIES) + #if ( _interface_link_directories ) + # target_link_directories( ${arg_TO} ${_scope} ${_interface_link_directories}) + #endif() + # + #get_target_property(torch_interface_link_libraries + # torch INTERFACE_LINK_LIBRARIES) + #if ( torch_interface_link_libraries ) + # list(APPEND AMS_APP_LIBRARIES "${torch_interface_link_libraries}") + #endif() + + list(APPEND AMS_TORCH_LIBRARY torch) list(APPEND AMS_APP_DEFINES "-D__ENABLE_TORCH__") - set(BLA_VENDER OpenBLAS) - find_package(BLAS REQUIRED) - list(APPEND AMS_APP_LIBRARIES "${BLAS_LIBRARIES}") + #set(BLA_VENDER OpenBLAS) + #find_package(BLAS REQUIRED) + #list(APPEND AMS_APP_LIBRARIES "${BLAS_LIBRARIES}") endif() # ------------------------------------------------------------------------------ @@ -253,6 +294,98 @@ if (WITH_PERFFLOWASPECT) endif() +macro(inherit_target_nostatic) + set(options) + set(singleValueArgs TO FROM OBJECT) + set(multiValueArgs) + + # Parse the arguments + cmake_parse_arguments(arg "${options}" "${singleValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + # Check arguments + if ( NOT DEFINED arg_TO ) + message( FATAL_ERROR "Must provide a TO argument to the 'blt_inherit_target' macro" ) + endif() + + if ( NOT DEFINED arg_FROM ) + message( FATAL_ERROR "Must provide a FROM argument to the 'blt_inherit_target' macro" ) + endif() + + set(_scope INTERFACE) + + get_target_property(_interface_system_includes + ${arg_FROM} INTERFACE_SYSTEM_INCLUDE_DIRECTORIES) + if ( _interface_system_includes ) + target_include_directories(${arg_TO} SYSTEM ${_scope} ${_interface_system_includes}) + endif() + + get_target_property(_interface_includes + ${arg_FROM} INTERFACE_INCLUDE_DIRECTORIES) + if ( _interface_includes ) + target_include_directories(${arg_TO} ${_scope} ${_interface_includes}) + endif() + + get_target_property(_interface_defines + ${arg_FROM} INTERFACE_COMPILE_DEFINITIONS) + if ( _interface_defines ) + target_compile_definitions( ${arg_TO} ${_scope} ${_interface_defines}) + endif() + + if( ${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.13.0" ) + get_target_property(_interface_link_options + ${arg_FROM} INTERFACE_LINK_OPTIONS) + if ( _interface_link_options ) + target_link_options( ${arg_TO} ${_scope} ${_interface_link_options}) + endif() + endif() + + get_target_property(_interface_compile_options + ${arg_FROM} INTERFACE_COMPILE_OPTIONS) + if ( _interface_compile_options ) + target_compile_options( ${arg_TO} ${_scope} ${_interface_compile_options}) + endif() + + if ( NOT arg_OBJECT ) + #get_target_property(_interface_link_directories + # ${arg_FROM} INTERFACE_LINK_DIRECTORIES) + #if ( _interface_link_directories ) + # target_link_directories( ${arg_TO} ${_scope} ${_interface_link_directories}) + #endif() + + #get_target_property(_interface_link_libraries + # ${arg_FROM} INTERFACE_LINK_LIBRARIES) + #if ( _interface_link_libraries ) + # target_link_libraries( ${arg_TO} ${_scope} ${_interface_link_libraries}) + #endif() + endif() + +endmacro(inherit_target_nostatic) + + +if (EXCLUDE_STATIC_LIBS) + set(NONSTATIC_AMS_APP_INTERFACE_LIBRARIES "") + set(NONSTATIC_AMS_APP_PRIVATE_LIBRARIES "") + foreach (THIS_LIB ${AMS_APP_LIBRARIES}) + list(APPEND NONSTATIC_AMS_APP_INTERFACE_LIBRARIES ${THIS_LIB}) + if (TARGET ${THIS_LIB}) + get_target_property(target_type ${THIS_LIB} TYPE) + if (NOT target_type STREQUAL STATIC_LIBRARY) + list(APPEND NONSTATIC_AMS_APP_PRIVATE_LIBRARIES ${THIS_LIB}) + else() + add_library("${THIS_LIB}::nostatic" INTERFACE IMPORTED) + inherit_target_nostatic(TO "${THIS_LIB}::nostatic" FROM ${THIS_LIB}) + list(APPEND NONSTATIC_AMS_APP_PRIVATE_LIBRARIES "${THIS_LIB}::nostatic") + endif() + else() + get_filename_component(THIS_EXT ${THIS_LIB} EXT) + if (NOT ".a" STREQUAL "${THIS_EXT}") + list(APPEND NONSTATIC_AMS_APP_PRIVATE_LIBRARIES ${THIS_LIB}) + endif() + endif() + endforeach() +endif() + add_subdirectory(src) # ------------------------------------------------------------------------------ diff --git a/examples/main.cpp b/examples/main.cpp index ed4f7a87..67ce65bb 100644 --- a/examples/main.cpp +++ b/examples/main.cpp @@ -179,13 +179,13 @@ int run(const char *device_name, CALIPER(CALI_MARK_BEGIN("Setup");) const bool use_device = std::strcmp(device_name, "cpu") != 0; - AMSDBType dbType = AMSDBType::None; + AMSDBType dbType = AMSDBType::DDDBNone; if (std::strcmp(db_type, "csv") == 0) { - dbType = AMSDBType::CSV; + dbType = AMSDBType::DBCSV; } else if (std::strcmp(db_type, "hdf5") == 0) { - dbType = AMSDBType::HDF5; + dbType = AMSDBType::DBHDF5; } else if (std::strcmp(db_type, "rmq") == 0) { - dbType = AMSDBType::RMQ; + dbType = AMSDBType::DBRMQ; } AMSUQPolicy uq_policy; diff --git a/src/AMSlib/CMakeLists.txt b/src/AMSlib/CMakeLists.txt index 8341439d..e00c95e4 100644 --- a/src/AMSlib/CMakeLists.txt +++ b/src/AMSlib/CMakeLists.txt @@ -42,7 +42,13 @@ target_include_directories(AMS PUBLIC $) target_include_directories(AMS PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_link_directories(AMS PUBLIC ${AMS_APP_LIB_DIRS}) -target_link_libraries(AMS PUBLIC ${AMS_APP_LIBRARIES} stdc++fs) +if (EXCLUDE_STATIC_LIBS) + target_link_libraries(AMS PRIVATE ${NONSTATIC_AMS_APP_PRIVATE_LIBRARIES} ${AMS_TORCH_LIBRARY} stdc++fs stdc++ /usr/tce/packages/intel/intel-2022.1.0/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libintlc.so m) + target_link_libraries(AMS INTERFACE ${NONSTATIC_AMS_APP_INTERFACE_LIBRARIES} stdc++fs stdc++ /usr/tce/packages/intel/intel-2022.1.0/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libintlc.so m) + target_link_options(AMS PRIVATE "-Wl,--unresolved-symbols=ignore-all") +else() + target_link_libraries(AMS PUBLIC ${AMS_APP_LIBRARIES} stdc++fs stdc++ /usr/tce/packages/intel/intel-2022.1.0/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libintlc.so m) +endif() #------------------------------------------------------------------------------- # create the configuration header file with the respective information @@ -75,7 +81,7 @@ install(TARGETS AMS DESTINATION lib) install(EXPORT AMSTargets - FILE AMS.cmake + FILE AMSConfig.cmake DESTINATION lib/cmake/AMS) install(FILES ${PROJECT_BINARY_DIR}/include/AMS.h DESTINATION include) diff --git a/src/AMSlib/include/AMS.h b/src/AMSlib/include/AMS.h index 64488948..51e6a8a8 100644 --- a/src/AMSlib/include/AMS.h +++ b/src/AMSlib/include/AMS.h @@ -55,7 +55,7 @@ typedef enum { typedef enum { UBALANCED = 0, BALANCED } AMSExecPolicy; -typedef enum { None = 0, CSV, REDIS, HDF5, RMQ } AMSDBType; +typedef enum { DDDBNone = 0, DBCSV, DBREDIS, DBHDF5, DBRMQ } AMSDBType; // TODO: create a cleaner interface that separates UQ type (FAISS, DeltaUQ) with policy (max, mean). enum struct AMSUQPolicy { diff --git a/src/AMSlib/ml/surrogate.hpp b/src/AMSlib/ml/surrogate.hpp index 997ad17a..e9e7751f 100644 --- a/src/AMSlib/ml/surrogate.hpp +++ b/src/AMSlib/ml/surrogate.hpp @@ -410,6 +410,8 @@ class SurrogateModel else _load(new_path, "cuda"); } + + AMSResourceType getModelResource() const { return model_resource; } }; template diff --git a/src/AMSlib/ml/uq.hpp b/src/AMSlib/ml/uq.hpp index dabe140f..6b6227bb 100644 --- a/src/AMSlib/ml/uq.hpp +++ b/src/AMSlib/ml/uq.hpp @@ -63,6 +63,8 @@ class UQ if (uqPolicy == AMSUQPolicy::RandomUQ) randomUQ = std::make_unique(resourceLocation, threshold); + + DBG(UQ, "UQ Model is of type %d", uqPolicy) } PERFFASPECT() @@ -73,21 +75,40 @@ class UQ { if ((uqPolicy == AMSUQPolicy::DeltaUQ_Mean) || (uqPolicy == AMSUQPolicy::DeltaUQ_Max)) { + + auto &rm = ams::ResourceManager::getInstance(); + CALIPER(CALI_MARK_BEGIN("DELTAUQ");) const size_t ndims = outputs.size(); std::vector outputs_stdev(ndims); // TODO: Enable device-side allocation and predicate calculation. - auto &rm = ams::ResourceManager::getInstance(); for (int dim = 0; dim < ndims; ++dim) outputs_stdev[dim] = rm.allocate(totalElements, AMSResourceType::HOST); CALIPER(CALI_MARK_BEGIN("SURROGATE");) - DBG(Workflow, - "Model exists, I am calling DeltaUQ surrogate (for all data)"); + DBG(UQ, + "Model exists, I am calling DeltaUQ surrogate [%ld %ld] -> (mu:[%ld " + "%ld], std:[%ld %ld])", + totalElements, + inputs.size(), + totalElements, + outputs.size(), + totalElements, + inputs.size()); surrogate->evaluate(totalElements, inputs, outputs, outputs_stdev); CALIPER(CALI_MARK_END("SURROGATE");) + // FIXME: We do something sub-optimal. We copy all the data from the GPU + // to the CPU and then we compute the predicate. Then we copy back the computed + // predicate to the device. We should avoid this unecessary back and forth. + bool *predicate = p_ml_acceptable; + if (surrogate->getModelResource() == AMSResourceType::DEVICE) { + predicate = rm.allocate(totalElements, AMSResourceType::HOST); + rm.copy(p_ml_acceptable, predicate); + } + + if (uqPolicy == AMSUQPolicy::DeltaUQ_Mean) { for (size_t i = 0; i < totalElements; ++i) { // Use double for increased precision, range in the calculation @@ -95,7 +116,7 @@ class UQ for (size_t dim = 0; dim < ndims; ++dim) mean += outputs_stdev[dim][i]; mean /= ndims; - p_ml_acceptable[i] = (mean < threshold); + predicate[i] = (mean < threshold); } } else if (uqPolicy == AMSUQPolicy::DeltaUQ_Max) { for (size_t i = 0; i < totalElements; ++i) { @@ -106,12 +127,17 @@ class UQ break; } - p_ml_acceptable[i] = is_acceptable; + predicate[i] = is_acceptable; } } else { THROW(std::runtime_error, "Invalid UQ policy"); } + if (surrogate->getModelResource() == AMSResourceType::DEVICE) { + rm.copy(predicate, p_ml_acceptable); + rm.deallocate(predicate, AMSResourceType::HOST); + } + for (int dim = 0; dim < ndims; ++dim) rm.deallocate(outputs_stdev[dim], AMSResourceType::HOST); CALIPER(CALI_MARK_END("DELTAUQ");) diff --git a/src/AMSlib/wf/basedb.hpp b/src/AMSlib/wf/basedb.hpp index 34e49d19..16ce5be2 100644 --- a/src/AMSlib/wf/basedb.hpp +++ b/src/AMSlib/wf/basedb.hpp @@ -226,7 +226,7 @@ class csvDB final : public FileDB /** * @brief Return the DB enumerationt type (File, Redis etc) */ - AMSDBType dbType() { return AMSDBType::CSV; }; + AMSDBType dbType() { return AMSDBType::DBCSV; }; /** * @brief Takes an input and an output vector each holding 1-D vectors data, and @@ -310,7 +310,7 @@ class hdf5DB final : public FileDB */ hid_t getDataSet(hid_t group, std::string dName, - const size_t Chunk = 32L * 1024L * 1024L) + const size_t Chunk = 256L * 1024L) { // Our datasets a.t.m are 1-D vectors const int nDims = 1; @@ -488,7 +488,7 @@ class hdf5DB final : public FileDB /** * @brief Return the DB enumerationt type (File, Redis etc) */ - AMSDBType dbType() { return AMSDBType::HDF5; }; + AMSDBType dbType() { return AMSDBType::DBHDF5; }; /** @@ -586,7 +586,7 @@ class RedisDB : public BaseDB /** * @brief Return the DB enumerationt type (File, Redis etc) */ - AMSDBType dbType() { return AMSDBType::REDIS; }; + AMSDBType dbType() { return AMSDBType::DBREDIS; }; inline std::string info() { return _redis->info(); } @@ -2339,7 +2339,7 @@ class RabbitMQDB final : public BaseDB /** * @brief Return the DB enumerationt type (File, Redis etc) */ - AMSDBType dbType() { return AMSDBType::RMQ; }; + AMSDBType dbType() { return AMSDBType::DBRMQ; }; void close() { @@ -2425,18 +2425,18 @@ class DBManager } switch (dbType) { - case AMSDBType::CSV: + case AMSDBType::DBCSV: return std::make_shared>(dbPath, rId); #ifdef __ENABLE_REDIS__ - case AMSDBType::REDIS: + case AMSDBType::DBREDIS: return std::make_shared>(dbPath, rId); #endif #ifdef __ENABLE_HDF5__ - case AMSDBType::HDF5: + case AMSDBType::DBHDF5: return std::make_shared>(dbPath, rId); #endif #ifdef __ENABLE_RMQ__ - case AMSDBType::RMQ: + case AMSDBType::DBRMQ: return std::make_shared>(dbPath, rId); #endif default: diff --git a/src/AMSlib/wf/cuda/utilities.cuh b/src/AMSlib/wf/cuda/utilities.cuh index 3969688a..90fe26d4 100644 --- a/src/AMSlib/wf/cuda/utilities.cuh +++ b/src/AMSlib/wf/cuda/utilities.cuh @@ -16,6 +16,7 @@ #include #include "wf/resource_manager.hpp" +#include "wf/debug.h" //#include //#include @@ -35,23 +36,10 @@ __device__ __inline__ int pow2i(int e) { return 1 << e; } inline void __cudaSafeCall(cudaError err, const char* file, const int line) { -#ifdef CUDA_ERROR_CHECK - if (cudaSuccess != err) { - fprintf(stderr, - "cudaSafeCall() failed at %s:%i : %s\n", + CFATAL(CUDA, (cudaSuccess != err), "cudaSafeCall() failed at %s:%i : %s\n", file, line, cudaGetErrorString(err)); - - fprintf(stdout, - "cudaSafeCall() failed at %s:%i : %s\n", - file, - line, - cudaGetErrorString(err)); - exit(-1); - } -#endif - return; } diff --git a/src/AMSlib/wf/debug.cpp b/src/AMSlib/wf/debug.cpp index d2379609..86e4a4ef 100644 --- a/src/AMSlib/wf/debug.cpp +++ b/src/AMSlib/wf/debug.cpp @@ -5,6 +5,11 @@ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception */ +#include "wf/debug.h" + +#if defined(__ENABLE_CUDA__) && defined(__ENABLE_TORCH__) +#include +#endif #include #include @@ -40,3 +45,62 @@ void memUsage(double& vm_usage, double& resident_set) vm_usage = vsize / 1024.0; resident_set = rss * page_size; } + +void dumpTorchDeviceStats() +{ +#if defined(__ENABLE_CUDA__) && defined(__ENABLE_TORCH__) + c10::cuda::CUDACachingAllocator::emptyCache(); + int curr_device = c10::cuda::current_device(); + c10::cuda::CUDACachingAllocator::DeviceStats stats = + c10::cuda::CUDACachingAllocator::getDeviceStats(curr_device); + + DBG(TorchDeviceStats, + "Current device according to torch has id : %d", + curr_device); + + for (auto S : stats.allocated_bytes) { + DBG(TorchDeviceStats, + "Allocated Current: %g (MBytes) Peak: %g (MBytes) Allocated: %G " + "(MBytes) Freed: " + "%g (MBytes)", + (double)(S.current) / (1024.0 * 1024.0), + (double)(S.peak) / (1024.0 * 1024.0), + (double)(S.allocated) / (1024.0 * 1024.0), + (double)(S.freed) / (1024.0 * 1024.0)); + } + + for (auto S : stats.reserved_bytes) { + DBG(TorchDeviceStats, + "Reserved Current: %g (MBytes) Peak: %g (MBytes) Allocated: %G " + "(MBytes) Freed: " + "%g (MBytes)", + (double)(S.current) / (1024.0 * 1024.0), + (double)(S.peak) / (1024.0 * 1024.0), + (double)(S.allocated) / (1024.0 * 1024.0), + (double)(S.freed) / (1024.0 * 1024.0)); + } + + for (auto S : stats.active_bytes) { + DBG(TorchDeviceStats, + "Active Current: %g (MBytes) Peak: %g (MBytes) Allocated: %G " + "(MBytes) Freed: " + "%g (MBytes)", + (double)(S.current) / (1024.0 * 1024.0), + (double)(S.peak) / (1024.0 * 1024.0), + (double)(S.allocated) / (1024.0 * 1024.0), + (double)(S.freed) / (1024.0 * 1024.0)); + } + + for (auto S : stats.inactive_split_bytes) { + DBG(TorchDeviceStats, + "Inactive Split Bytes Current: %g (MBytes) Peak: %g (MBytes) " + "Allocated: %G " + "(MBytes) Freed: " + "%g (MBytes)", + (double)(S.current) / (1024.0 * 1024.0), + (double)(S.peak) / (1024.0 * 1024.0), + (double)(S.allocated) / (1024.0 * 1024.0), + (double)(S.freed) / (1024.0 * 1024.0)); + } +#endif +} diff --git a/src/AMSlib/wf/debug.h b/src/AMSlib/wf/debug.h index 4927c32e..1bcbe99b 100644 --- a/src/AMSlib/wf/debug.h +++ b/src/AMSlib/wf/debug.h @@ -31,6 +31,8 @@ enum AMSVerbosity { void memUsage(double& vm_usage, double& resident_set); +void dumpTorchDeviceStats(); +void deviceMemoryInfo(size_t *free, size_t *total); inline std::atomic& getInfoLevelInternal() { @@ -80,7 +82,12 @@ inline uint32_t getVerbosityLevel() #define FATAL(id, ...) CFATAL(id, true, __VA_ARGS__) -#ifdef LIBAMS_VERBOSE +#define THROW(exception, msg) \ + throw exception(std::string(__FILE__) + ":" + std::to_string(__LINE__) + \ + " " + msg) + + +#ifdef AMS_DEBUG #define CWARNING(id, condition, ...) \ AMSPRINT(id, condition, AMSVerbosity::AMSWARNING, YEL, __VA_ARGS__) @@ -97,35 +104,54 @@ inline uint32_t getVerbosityLevel() #define DBG(id, ...) CDEBUG(id, true, __VA_ARGS__) -#define REPORT_MEM_USAGE(id, phase) \ - do { \ - double vm, rs; \ - size_t watermark, current_size, actual_size; \ - auto& rm = ams::ResourceManager::getInstance(); \ - memUsage(vm, rs); \ - DBG(id, "Memory usage at %s is VM:%g RS:%g", phase, vm, rs); \ - \ - for (int i = 0; i < AMSResourceType::RSEND; i++) { \ - if (rm.isActive((AMSResourceType)i)) { \ - rm.getAllocatorStats((AMSResourceType)i, \ - watermark, \ - current_size, \ - actual_size); \ - DBG(id, \ - "Allocator: %s HWM:%lu CS:%lu AS:%lu) ", \ - rm.getAllocatorName((AMSResourceType)i) \ - .c_str(), \ - watermark, \ - current_size, \ - actual_size); \ - } \ - } \ +#define REPORT_MEM_USAGE(id, phase) \ + do { \ + double vm, rs; \ + size_t watermark, current_size, actual_size; \ + auto& rm = ams::ResourceManager::getInstance(); \ + memUsage(vm, rs); \ + DBG(id, "Memory usage at %s is VM:%g RS:%g", phase, vm, rs); \ + \ + for (int i = 0; i < AMSResourceType::RSEND; i++) { \ + if (rm.isActive((AMSResourceType)i)) { \ + rm.getAllocatorStats((AMSResourceType)i, \ + watermark, \ + current_size, \ + actual_size); \ + DBG(id, \ + "Allocator (in MBytes): %s HWM:%g CS:%g AS:%g) ", \ + rm.getAllocatorName((AMSResourceType)i).c_str(), \ + (double)(watermark) / (1024.0 * 1024.0), \ + (double)(current_size) / (1024.0 * 1024.0), \ + (double)(actual_size) / (1024.0 * 1024.0)); \ + } \ + } \ + dumpTorchDeviceStats(); \ + size_t free, total; \ + deviceMemoryInfo(&free, &total); \ + DBG(id, \ + "Device Memory Usage (cuda-driver) Used: %g MBytes, Free: %g MBytes", \ + ((double)(total - free)) / (1024.0 * 1024.0), \ + (double)(free) / (1024.0 * 1024.0)); \ } while (0); -#define THROW(exception, msg) \ - throw exception(std::string(__FILE__) + ":" + std::to_string(__LINE__) + \ - " " + msg) -#else // LIBAMS_VERBOSE is disabled +#ifdef __ENABLE_CUDA__ +// NOTE: Regardless of condition we synchronize. We only emit a message based on condition. +#define _CAMSDebugDeviceSync(id, condition, fn, ln, ...) \ + do{ \ + AMSDeviceSync(fn, ln); \ + CDEBUG(id, condition, __VA_ARGS__) \ + }while(0); + +#define CAMSDebugDeviceSync(id, condition, ...) _CAMSDebugDeviceSync(id, condition, __FILE__, __LINE__, __VA_ARGS__) +#define AMSDebugDeviceSync(id, ...) _CAMSDebugDeviceSync(id, true, __FILE__, __LINE__, __VA_ARGS__) +#else +#define CAMSDebugDeviceSync(id, condition, ...) +#define AMSDebugDeviceSync(id, ...) +#endif + + +#else // LIBAMS_DEBUG is disabled #define CWARNING(id, condition, ...) #define WARNING(id, ...) @@ -138,7 +164,11 @@ inline uint32_t getVerbosityLevel() #define DBG(id, ...) +#define REPORT_MEM_USAGE(id, phase) \ + +#define CAMSDebugDeviceSync(id, condition, ...) +#define AMSDebugDeviceSync(id, ...) -#endif // LIBAMS_VERBOSE +#endif // AMS_DEBUG -#endif // _OMPTARGET_DEBUG_H +#endif // __AMS_DEBUG__ diff --git a/src/AMSlib/wf/device.hpp b/src/AMSlib/wf/device.hpp index 73b784c0..e1c6cc10 100644 --- a/src/AMSlib/wf/device.hpp +++ b/src/AMSlib/wf/device.hpp @@ -153,6 +153,13 @@ void deviceCheckErrors(const char *file, const int line) } +void deviceMemoryInfo(size_t *free, size_t *total) +{ +#ifdef __ENABLE_CUDA__ + cudaMemGetInfo(free, total); +#endif +} + #ifdef __ENABLE_CUDA__ #include @@ -184,7 +191,6 @@ __global__ void random_uq_device(int seed, uq_flags[id] = (x <= acceptable_error); } - #include PERFFASPECT() inline void DtoDMemcpy(void *dest, void *src, size_t nBytes) @@ -209,6 +215,11 @@ inline void DtoHMemcpy(void *dest, void *src, size_t nBytes) { cudaMemcpy(dest, src, nBytes, cudaMemcpyDeviceToHost); } + +inline void AMSDeviceSync(const char* file, const int line){ + __cudaSafeCall(cudaDeviceSynchronize(), file, line); +} + #else PERFFASPECT() inline void DtoDMemcpy(void *dest, void *src, size_t nBytes) @@ -236,6 +247,10 @@ inline void DtoHMemcpy(void *dest, void *src, size_t nBytes) std::cerr << "DtoH Memcpy Not Enabled" << std::endl; exit(-1); } + +inline void AMSDeviceSync(const char* file, const int line){ + std::cerr << "GPU Not enabled" << std::endl; +} #endif #endif diff --git a/src/AMSlib/wf/redist_load.hpp b/src/AMSlib/wf/redist_load.hpp index bbedf3b5..3f3d05f8 100644 --- a/src/AMSlib/wf/redist_load.hpp +++ b/src/AMSlib/wf/redist_load.hpp @@ -103,77 +103,10 @@ class AMSLoadBalancer /** @brief The memory location of the data (GPU (DEVICE), CPU (HOST) ) */ AMSResourceType resource; -private: - /** @brief Computes the number of balanced elements each process will gather and initializes - * memory structures. - * @param[in] numIn The number of input dimensions - * @param[in] numOut The number of input dimensions - * @param[in] resource The resource type to allocate data in. - - * @details The function computes the total number of elements every rank will need to balance. - * It initializes the 'dataElements', 'displs' on the root node and the localLoad, balancedLoad - * across all ranks. - */ - void init(int numIn, int numOut, AMSResourceType resource) - { - auto& rm = ams::ResourceManager::getInstance(); - // We need to store information - if (rId == root) { - dataElements = - rm.allocate(worldSize, AMSResourceType::HOST); - displs = rm.allocate(worldSize + 1, - AMSResourceType::HOST); - } - - // Gather the the number of items from each rank - int rc = MPI_Gather(reinterpret_cast(&localLoad), - 1, - MPI_INT, - reinterpret_cast(dataElements), - 1, - MPI_INT, - root, - Comm); - CFATAL(LoadBalance, rc != MPI_SUCCESS, "Cannot gather per rank sizes") - - // Populate displacement array - if (rId == 0) { - globalLoad = 0; - displs[0] = static_cast(0); - for (size_t i = 0ul; i < worldSize; ++i) { - displs[i + 1] = dataElements[i] + displs[i]; - } - globalLoad = displs[worldSize]; - } - - balancedLoad = computeBalanceLoad(); - - if (rId == root) { - balancedElements = - rm.ResourceManager::allocate(worldSize, AMSResourceType::HOST); - balancedDispls = - rm.ResourceManager::allocate(worldSize, AMSResourceType::HOST); - for (int i = 0; i < worldSize; i++) { - balancedElements[i] = (globalLoad / worldSize) + - static_cast(i < (globalLoad % worldSize)); - if (i != 0) - balancedDispls[i] = balancedElements[i] + balancedDispls[i - 1]; - else - balancedDispls[i] = 0; - } - } - - for (int i = 0; i < numIn; i++) { - distInputs.push_back( - rm.allocate(balancedLoad, resource)); - } - - for (int i = 0; i < numOut; i++) { - distOutputs.push_back( - rm.allocate(balancedLoad, resource)); - } - } + /** @brief Check if load balancer is initialized **/ + bool _initialized; +private: /** @brief Computes the number of elements every rank will receive after balancing. * @returns the number of elements computed by this rank. **/ @@ -266,10 +199,11 @@ class AMSLoadBalancer AMSResourceType resource) { FPTypeValue *temp_data; - auto& rm = ams::ResourceManager::getInstance(); + auto &rm = ams::ResourceManager::getInstance(); if (rId == root) { - temp_data = rm.ResourceManager::allocate(globalLoad, resource); + temp_data = + rm.ResourceManager::allocate(globalLoad, resource); } for (int i = 0; i < src.size(); i++) { @@ -304,17 +238,8 @@ class AMSLoadBalancer * @param[in] worldSize The total number of ranks in respect to the Comm communicator. * @param[in] localLoad The number of elements this rank has to compute originally (before load balance). * @param[in] Comm The MPI communicator. - * @param[in] numIn The number of input vectors to be balanced. - * @param[in] numOut The number of output vectors to be balanced. - * @param[in] resource The location of data allocations (CPU|GPU). - */ - AMSLoadBalancer(int rId, - int worldSize, - int localLoad, - MPI_Comm comm, - int numIn, - int numOut, - AMSResourceType resource) + */ + AMSLoadBalancer(int rId, int worldSize, int localLoad, MPI_Comm comm) : rId(rId), worldSize(worldSize), localLoad(localLoad), @@ -324,23 +249,27 @@ class AMSLoadBalancer dataElements(nullptr), balancedElements(nullptr), balancedDispls(nullptr), - resource(resource) + resource(resource), + _initialized(false) { - init(numIn, numOut, resource); } /** @brief deallocates all objects of this load balancing transcation */ ~AMSLoadBalancer() { - auto& rm = ams::ResourceManager::getInstance(); - CINFO(LoadBalance, root==rId, "Total data %d Data per rank %d", globalLoad, balancedLoad); + if (!_initialized) return; + + auto &rm = ams::ResourceManager::getInstance(); + CINFO(LoadBalance, + root == rId, + "Total data %d Data per rank %d", + globalLoad, + balancedLoad); if (displs) rm.deallocate(displs, AMSResourceType::HOST); - if (dataElements) - rm.deallocate(dataElements, AMSResourceType::HOST); + if (dataElements) rm.deallocate(dataElements, AMSResourceType::HOST); if (balancedElements) rm.deallocate(balancedElements, AMSResourceType::HOST); - if (balancedDispls) - rm.deallocate(balancedDispls, AMSResourceType::HOST); + if (balancedDispls) rm.deallocate(balancedDispls, AMSResourceType::HOST); for (int i = 0; i < distOutputs.size(); i++) rm.deallocate(distOutputs[i], resource); @@ -350,6 +279,73 @@ class AMSLoadBalancer } }; + /** @brief Computes the number of balanced elements each process will gather and initializes + * memory structures. + * @param[in] numIn The number of input dimensions + * @param[in] numOut The number of input dimensions + * @param[in] resource The resource type to allocate data in. + + * @details The function computes the total number of elements every rank will need to balance. + * It initializes the 'dataElements', 'displs' on the root node and the localLoad, balancedLoad + * across all ranks. + */ + void init(int numIn, int numOut, AMSResourceType resource) + { + auto &rm = ams::ResourceManager::getInstance(); + // We need to store information + if (rId == root) { + dataElements = rm.allocate(worldSize, AMSResourceType::HOST); + displs = rm.allocate(worldSize + 1, AMSResourceType::HOST); + } + + // Gather the the number of items from each rank + int rc = MPI_Gather(reinterpret_cast(&localLoad), + 1, + MPI_INT, + reinterpret_cast(dataElements), + 1, + MPI_INT, + root, + Comm); + CFATAL(LoadBalance, rc != MPI_SUCCESS, "Cannot gather per rank sizes") + + // Populate displacement array + if (rId == 0) { + globalLoad = 0; + displs[0] = static_cast(0); + for (size_t i = 0ul; i < worldSize; ++i) { + displs[i + 1] = dataElements[i] + displs[i]; + } + globalLoad = displs[worldSize]; + } + + balancedLoad = computeBalanceLoad(); + + if (rId == root) { + balancedElements = + rm.ResourceManager::allocate(worldSize, AMSResourceType::HOST); + balancedDispls = + rm.ResourceManager::allocate(worldSize, AMSResourceType::HOST); + for (int i = 0; i < worldSize; i++) { + balancedElements[i] = (globalLoad / worldSize) + + static_cast(i < (globalLoad % worldSize)); + if (i != 0) + balancedDispls[i] = balancedElements[i] + balancedDispls[i - 1]; + else + balancedDispls[i] = 0; + } + } + + for (int i = 0; i < numIn; i++) { + distInputs.push_back(rm.allocate(balancedLoad, resource)); + } + + for (int i = 0; i < numOut; i++) { + distOutputs.push_back(rm.allocate(balancedLoad, resource)); + } + _initialized = true; + } + /** * @brief Reverse load balance in respect to the output vectors. * @param[out] outputs The vector to store all the output values gathered from their compute (remote) ranks. diff --git a/src/AMSlib/wf/workflow.hpp b/src/AMSlib/wf/workflow.hpp index 2f1c8656..6003f53d 100644 --- a/src/AMSlib/wf/workflow.hpp +++ b/src/AMSlib/wf/workflow.hpp @@ -63,7 +63,7 @@ class AMSWorkflow std::shared_ptr> DB; /** @brief The type of the database we will use (HDF5, CSV, etc) */ - AMSDBType dbType = AMSDBType::None; + AMSDBType dbType = AMSDBType::DDDBNone; /** @brief The process id. For MPI runs this is the rank */ const int rId; @@ -141,7 +141,6 @@ class AMSWorkflow DB->store(actualElems, hInputs, hOutputs); } rm.deallocate(pPtr, AMSResourceType::PINNED); - return; } @@ -149,7 +148,7 @@ class AMSWorkflow AMSWorkflow() : AppCall(nullptr), DB(nullptr), - dbType(AMSDBType::None), + dbType(AMSDBType::DDDBNone), appDataLoc(AMSResourceType::HOST), ePolicy(AMSExecPolicy::UBALANCED) { @@ -251,6 +250,7 @@ class AMSWorkflow { CALIPER(CALI_MARK_BEGIN("AMSEvaluate");) + CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-External Errors") CDEBUG(Workflow, rId == 0, "Entering Evaluate " @@ -276,13 +276,17 @@ class AMSWorkflow totalElements, reinterpret_cast(origInputs.data()), reinterpret_cast(origOutputs.data())); + + CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-AppCall Errors") CALIPER(CALI_MARK_END("PHYSICS MODULE");) if (DB) { CALIPER(CALI_MARK_BEGIN("DBSTORE");) Store(totalElements, tmpIn, origOutputs); + CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-Store Errors") CALIPER(CALI_MARK_END("DBSTORE");) } CALIPER(CALI_MARK_END("AMSEvaluate");) + CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-Evaluate Errors") return; } @@ -300,6 +304,7 @@ class AMSWorkflow CALIPER(CALI_MARK_BEGIN("UQ_MODULE");) UQModel->evaluate(totalElements, origInputs, origOutputs, p_ml_acceptable); CALIPER(CALI_MARK_END("UQ_MODULE");) + CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-UQ/Surrogate Errors") DBG(Workflow, "Computed Predicates") @@ -324,6 +329,7 @@ class AMSWorkflow const long packedElements = data_handler::pack( appDataLoc, predicate, totalElements, origInputs, packedInputs); CALIPER(CALI_MARK_END("PACK");) + CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-Pack Errors") // Pointer values which store output data values // to be computed using the eos function. @@ -341,8 +347,9 @@ class AMSWorkflow #ifdef __ENABLE_MPI__ CALIPER(CALI_MARK_BEGIN("LOAD BALANCE MODULE");) AMSLoadBalancer lBalancer( - rId, wSize, packedElements, Comm, inputDim, outputDim, appDataLoc); + rId, wSize, packedElements, Comm); if (ePolicy == AMSExecPolicy::BALANCED && Comm) { + lBalancer.init(inputDim, outputDim, appDataLoc); lBalancer.scatterInputs(packedInputs, appDataLoc); iPtr = reinterpret_cast(lBalancer.inputs()); oPtr = reinterpret_cast(lBalancer.outputs()); @@ -351,12 +358,14 @@ class AMSWorkflow CALIPER(CALI_MARK_END("LOAD BALANCE MODULE");) #endif + // ---- 3b: call the physics module and store in the data base if (packedElements > 0) { CALIPER(CALI_MARK_BEGIN("PHYSICS MODULE");) AppCall(probDescr, lbElements, iPtr, oPtr); CALIPER(CALI_MARK_END("PHYSICS MODULE");) } + CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-AppCall Errors") #ifdef __ENABLE_MPI__ CALIPER(CALI_MARK_BEGIN("LOAD BALANCE MODULE");) @@ -372,6 +381,7 @@ class AMSWorkflow data_handler::unpack( appDataLoc, predicate, totalElements, packedOutputs, origOutputs); CALIPER(CALI_MARK_END("UNPACK");) + CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-UnPack Errors") DBG(Workflow, "Finished physics evaluation") @@ -383,6 +393,7 @@ class AMSWorkflow Store(packedElements, packedInputs, packedOutputs); CALIPER(CALI_MARK_END("DBSTORE");) } + CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-Store Errors") // ----------------------------------------------------------------- // Deallocate temporal data @@ -405,6 +416,7 @@ class AMSWorkflow REPORT_MEM_USAGE(Workflow, "End") CALIPER(CALI_MARK_END("AMSEvaluate");) + CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-Evaluate Errors") } }; diff --git a/tests/AMSlib/CMakeLists.txt b/tests/AMSlib/CMakeLists.txt index 5385e079..e9290f3a 100644 --- a/tests/AMSlib/CMakeLists.txt +++ b/tests/AMSlib/CMakeLists.txt @@ -5,9 +5,9 @@ function (BUILD_TEST exe source) add_executable(${exe} ${source}) - target_include_directories(${exe} PRIVATE "${PROJECT_SOURCE_DIR}/src/AMSlib/" umpire ${caliper_INCLUDE_DIR} ${MPI_INCLUDE_PATH}) + target_include_directories(${exe} PRIVATE "${PROJECT_SOURCE_DIR}/src/AMSlib/" umpire ${AMS_APP_INCLUDES} ${caliper_INCLUDE_DIR} ${MPI_INCLUDE_PATH}) target_link_directories(${exe} PRIVATE ${AMS_APP_LIB_DIRS}) - target_link_libraries(${exe} PRIVATE AMS ${AMS_APP_LIBRARIES}) + target_link_libraries(${exe} PRIVATE AMS ${AMS_APP_LIBRARIES} ${AMS_TORCH_LIBRARY}) target_compile_definitions(${exe} PRIVATE ${AMS_APP_DEFINES}) if (WITH_CUDA) @@ -39,6 +39,19 @@ if (WITH_TORCH) add_test(NAME AMSExampleSingleRandomUQ::HOST COMMAND ams_example --precision single --uqtype random -S ${CMAKE_CURRENT_SOURCE_DIR}/debug_model.pt -e 100) add_test(NAME AMSExampleDoubleRandomUQ::HOST COMMAND ams_example --precision double --uqtype random -S ${CMAKE_CURRENT_SOURCE_DIR}/debug_model.pt -e 100) + # UQ Tests + BUILD_TEST(ams_delta_uq_test ams_uq_model.cpp) + if (WITH_TORCH) + add_test(NAME AMSDeltaUQDoubleMean::HOST COMMAND ams_delta_uq_test 0 ${CMAKE_CURRENT_SOURCE_DIR}/torch.duq.cuda "double" 8 9 3 0.0) + add_test(NAME AMSDeltaUQDoubleMax::HOST COMMAND ams_delta_uq_test 0 ${CMAKE_CURRENT_SOURCE_DIR}/torch.duq.cuda "double" 8 9 4 0.0) + + if (WITH_CUDA) + add_test(NAME AMSDeltaUQDoubleMean::DEVICE COMMAND ams_delta_uq_test 1 ${CMAKE_CURRENT_SOURCE_DIR}/torch.duq.cuda "double" 8 9 3 0.0) + add_test(NAME AMSDeltaUQDoubleMax::DEVICE COMMAND ams_delta_uq_test 1 ${CMAKE_CURRENT_SOURCE_DIR}/torch.duq.cuda "double" 8 9 4 0.0) + endif() + endif() + #TODO Add tests with cpu model + BUILD_TEST(ams_update_model ams_update_model.cpp) ADDTEST(ams_update_model AMSUpdateModelDouble "double" ${CMAKE_CURRENT_SOURCE_DIR}/ConstantZeroModel_cpu.pt ${CMAKE_CURRENT_SOURCE_DIR}/ConstantOneModel_cpu.pt) endif() diff --git a/tests/AMSlib/ams_uq_model.cpp b/tests/AMSlib/ams_uq_model.cpp new file mode 100644 index 00000000..4a7262c8 --- /dev/null +++ b/tests/AMSlib/ams_uq_model.cpp @@ -0,0 +1,77 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#define SIZE (32L * 1024L + 3L) + +template +void model(UQ &model, + AMSResourceType resource, + int num_inputs, + int num_outputs) +{ + std::vector inputs; + std::vector outputs; + auto &ams_rm = ams::ResourceManager::getInstance(); + + for (int i = 0; i < num_inputs; i++) + inputs.push_back(ams_rm.allocate(SIZE, resource)); + + for (int i = 0; i < num_outputs; i++) + outputs.push_back(ams_rm.allocate(SIZE, resource)); + + bool *predicates = ams_rm.allocate(SIZE, resource); + + std::cout << "We are calling evaluate\n"; + model.evaluate(SIZE, inputs, outputs, predicates); + + + for (int i = 0; i < num_inputs; i++) + ams_rm.deallocate(const_cast(inputs[i]), resource); + + for (int i = 0; i < num_outputs; i++) + ams_rm.deallocate(outputs[i], resource); + + ams_rm.deallocate(predicates, resource); +} + + +int main(int argc, char *argv[]) +{ + using namespace ams; + auto &ams_rm = ResourceManager::getInstance(); + int use_device = std::atoi(argv[1]); + char *model_path = argv[2]; + char *data_type = argv[3]; + int num_inputs = std::atoi(argv[4]); + int num_outputs = std::atoi(argv[5]); + const AMSUQPolicy uq_policy = static_cast(std::atoi(argv[6])); + float threshold = std::atof(argv[7]); + + std::cout << "Executing on device " << use_device << "\n"; + + AMSResourceType resource = AMSResourceType::HOST; + if (use_device == 1) { + resource = AMSResourceType::DEVICE; + } + + ams_rm.init(); + + + if (std::strcmp("double", data_type) == 0) { + UQ UQModel(resource, uq_policy, nullptr, -1, model_path, threshold); + model(UQModel, resource, num_inputs, num_outputs); + } else if (std::strcmp("single", data_type) == 0) { + UQ UQModel(resource, uq_policy, nullptr, -1, model_path, threshold); + model(UQModel, resource, num_inputs, num_outputs); + } + + return 0; +} diff --git a/tests/AMSlib/torch.duq b/tests/AMSlib/torch.duq new file mode 100644 index 00000000..58ee24fc Binary files /dev/null and b/tests/AMSlib/torch.duq differ diff --git a/tests/AMSlib/torch.duq.cuda b/tests/AMSlib/torch.duq.cuda new file mode 100644 index 00000000..1a80a5a1 Binary files /dev/null and b/tests/AMSlib/torch.duq.cuda differ