diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5b8d9b4d..6443aed2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -39,6 +39,8 @@ option(WITH_TORCH_DEBUG "Compute RMSE of Surrogate Model and Physics Module"
option(WITH_TESTS "Compile tests" OFF)
option(WITH_REDIS "Use REDIS as a database back end" OFF)
option(WITH_HDF5 "Use HDF5 as a database back end" OFF)
+option(HDF5_USE_STATIC_LIBRARIES "Use static HDF5." OFF)
+set(HDF5_WITH_ZLIB "" CACHE FILETYPE "Use the following zlib for HDF5")
option(WITH_RMQ "Use RabbitMQ as a database back end (require a reachable and running RabbitMQ server service)" OFF)
option(WITH_AMS_DEBUG "Enable verbose messages" OFF)
option(WITH_PERFFLOWASPECT "Use PerfFlowAspect for Profiling" OFF)
@@ -46,6 +48,7 @@ option(WITH_WORKFLOW "Install python drivers used by the outer workflow" O
option(WITH_AMS_LIB "Install C++ library to support scientific applications" ON)
option(WITH_ADIAK "Use Adiak for recording metadata" OFF)
option(BUILD_SHARED_LIBS "Build using shared libraries" ON)
+option(EXCLUDE_STATIC_LIBS "Exclude static libs from the linking line" OFF)
if (WITH_MPI)
# SET(CMAKE_CXX_COMPILER "${MPI_CXX_COMPILER}" CACHE FILEPATH "CXX compiler overridden with MPI C++ wrapper")
@@ -69,6 +72,7 @@ if (WITH_CUDA)
if (BUILD_SHARED_LIBS)
set(CUDA_RUNTIME_LIBRARY "Shared")
else()
+ set(HDF5_USE_STATIC_LIBRARIES ON)
set(CUDA_RUNTIME_LIBRARY "Static")
endif()
@@ -91,7 +95,7 @@ if (WITH_CALIPER)
endif()
if (WITH_AMS_DEBUG)
- list(APPEND AMS_APP_DEFINES "-DLIBAMS_VERBOSE")
+ list(APPEND AMS_APP_DEFINES "-DAMS_DEBUG")
endif()
# ------------------------------------------------------------------------------
@@ -132,14 +136,17 @@ endif() # WITH_REDIS
if (WITH_HDF5)
if (HDF5_USE_STATIC_LIBRARIES)
- find_package(HDF5 NAMES hdf5 COMPONENTS C static NO_DEFAULT_PATH PATHS ${AMS_HDF5_DIR} ${AMS_HDF5_DIR}/share/cmake)
- list(APPEND AMS_APP_LIBRARIES ${HDF5_C_STATIC_LIBRARY})
- message(STATUS "HDF5 Static Library : ${HDF5_C_STATIC_LIBRARY}")
-else()
- find_package(HDF5 NAMES hdf5 COMPONENTS C shared NO_DEFAULT_PATH PATHS ${AMS_HDF5_DIR} ${AMS_HDF5_DIR}/share/cmake)
- list(APPEND AMS_APP_LIBRARIES ${HDF5_C_SHARED_LIBRARY})
- message(STATUS "HDF5 Shared Library : ${HDF5_C_SHARED_LIBRARY}")
-endif()
+ find_package(HDF5 NAMES hdf5 COMPONENTS C static NO_DEFAULT_PATH PATHS ${AMS_HDF5_DIR} ${AMS_HDF5_DIR}/share/cmake)
+ list(APPEND AMS_APP_LIBRARIES ${HDF5_C_STATIC_LIBRARY})
+ message(STATUS "HDF5 Static Library : ${HDF5_C_STATIC_LIBRARY}")
+ else()
+ find_package(HDF5 NAMES hdf5 COMPONENTS C shared NO_DEFAULT_PATH PATHS ${AMS_HDF5_DIR} ${AMS_HDF5_DIR}/share/cmake)
+ list(APPEND AMS_APP_LIBRARIES ${HDF5_C_SHARED_LIBRARY})
+ message(STATUS "HDF5 Shared Library : ${HDF5_C_SHARED_LIBRARY}")
+ endif()
+ if (NOT HDF5_WITH_ZLIB STREQUAL "")
+ list(APPEND AMS_APP_LIBRARIES ${HDF5_WITH_ZLIB})
+ endif()
list(APPEND AMS_APP_INCLUDES ${HDF5_INCLUDE_DIR})
list(APPEND AMS_APP_DEFINES "-D__ENABLE_HDF5__")
message(STATUS "HDF5 Include directories: ${HDF5_INCLUDE_DIR}")
@@ -191,16 +198,50 @@ if (WITH_TORCH)
find_package(Torch REQUIRED)
# This is annoying, torch populates all my cuda flags
# and resets them
- set(CMAKE_CUDA_FLAGS "")
+ # set(CMAKE_CUDA_FLAGS "")
set(CMAKE_CUDA_ARCHITECTURES ON)
- list(APPEND AMS_APP_INCLUDES "${TORCH_INCLUDE_DIRS}")
- list(APPEND AMS_APP_LIBRARIES "${TORCH_LIBRARIES}")
-
+ #get_target_property(torch_interface_system_includes
+ # torch INTERFACE_SYSTEM_INCLUDE_DIRECTORIES)
+ #if ( torch_interface_system_includes )
+ # list(APPEND AMS_APP_INCLUDES ${torch_interface_system_includes})
+ #endif()
+ #
+ #get_target_property(torch_interface_includes
+ # torch INTERFACE_INCLUDE_DIRECTORIES)
+ #if ( torch_interface_includes )
+ # list(APPEND AMS_APP_INCLUDES ${torch_interface_includes})
+ #endif()
+ #
+ #get_target_property(torch_interface_defines
+ # torch INTERFACE_COMPILE_DEFINITIONS)
+ #if ( troch_interface_defines )
+ # list(APPEND AMS_APP_DEFINES ${torch_interface_defines})
+ #endif()
+ #
+ #get_target_property(torch_interface_compile_options
+ # torch INTERFACE_COMPILE_OPTIONS)
+ #if ( torch_interface_compile_options )
+ # list(APPEND AMS_APP_DEFINES ${torch_interface_compile_options})
+ #endif()
+ #
+ #get_target_property(_interface_link_directories
+ # ${arg_FROM} INTERFACE_LINK_DIRECTORIES)
+ #if ( _interface_link_directories )
+ # target_link_directories( ${arg_TO} ${_scope} ${_interface_link_directories})
+ #endif()
+ #
+ #get_target_property(torch_interface_link_libraries
+ # torch INTERFACE_LINK_LIBRARIES)
+ #if ( torch_interface_link_libraries )
+ # list(APPEND AMS_APP_LIBRARIES "${torch_interface_link_libraries}")
+ #endif()
+
+ list(APPEND AMS_TORCH_LIBRARY torch)
list(APPEND AMS_APP_DEFINES "-D__ENABLE_TORCH__")
- set(BLA_VENDER OpenBLAS)
- find_package(BLAS REQUIRED)
- list(APPEND AMS_APP_LIBRARIES "${BLAS_LIBRARIES}")
+ #set(BLA_VENDER OpenBLAS)
+ #find_package(BLAS REQUIRED)
+ #list(APPEND AMS_APP_LIBRARIES "${BLAS_LIBRARIES}")
endif()
# ------------------------------------------------------------------------------
@@ -253,6 +294,98 @@ if (WITH_PERFFLOWASPECT)
endif()
+macro(inherit_target_nostatic)
+ set(options)
+ set(singleValueArgs TO FROM OBJECT)
+ set(multiValueArgs)
+
+ # Parse the arguments
+ cmake_parse_arguments(arg "${options}" "${singleValueArgs}"
+ "${multiValueArgs}" ${ARGN} )
+
+ # Check arguments
+ if ( NOT DEFINED arg_TO )
+ message( FATAL_ERROR "Must provide a TO argument to the 'blt_inherit_target' macro" )
+ endif()
+
+ if ( NOT DEFINED arg_FROM )
+ message( FATAL_ERROR "Must provide a FROM argument to the 'blt_inherit_target' macro" )
+ endif()
+
+ set(_scope INTERFACE)
+
+ get_target_property(_interface_system_includes
+ ${arg_FROM} INTERFACE_SYSTEM_INCLUDE_DIRECTORIES)
+ if ( _interface_system_includes )
+ target_include_directories(${arg_TO} SYSTEM ${_scope} ${_interface_system_includes})
+ endif()
+
+ get_target_property(_interface_includes
+ ${arg_FROM} INTERFACE_INCLUDE_DIRECTORIES)
+ if ( _interface_includes )
+ target_include_directories(${arg_TO} ${_scope} ${_interface_includes})
+ endif()
+
+ get_target_property(_interface_defines
+ ${arg_FROM} INTERFACE_COMPILE_DEFINITIONS)
+ if ( _interface_defines )
+ target_compile_definitions( ${arg_TO} ${_scope} ${_interface_defines})
+ endif()
+
+ if( ${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.13.0" )
+ get_target_property(_interface_link_options
+ ${arg_FROM} INTERFACE_LINK_OPTIONS)
+ if ( _interface_link_options )
+ target_link_options( ${arg_TO} ${_scope} ${_interface_link_options})
+ endif()
+ endif()
+
+ get_target_property(_interface_compile_options
+ ${arg_FROM} INTERFACE_COMPILE_OPTIONS)
+ if ( _interface_compile_options )
+ target_compile_options( ${arg_TO} ${_scope} ${_interface_compile_options})
+ endif()
+
+ if ( NOT arg_OBJECT )
+ #get_target_property(_interface_link_directories
+ # ${arg_FROM} INTERFACE_LINK_DIRECTORIES)
+ #if ( _interface_link_directories )
+ # target_link_directories( ${arg_TO} ${_scope} ${_interface_link_directories})
+ #endif()
+
+ #get_target_property(_interface_link_libraries
+ # ${arg_FROM} INTERFACE_LINK_LIBRARIES)
+ #if ( _interface_link_libraries )
+ # target_link_libraries( ${arg_TO} ${_scope} ${_interface_link_libraries})
+ #endif()
+ endif()
+
+endmacro(inherit_target_nostatic)
+
+
+if (EXCLUDE_STATIC_LIBS)
+ set(NONSTATIC_AMS_APP_INTERFACE_LIBRARIES "")
+ set(NONSTATIC_AMS_APP_PRIVATE_LIBRARIES "")
+ foreach (THIS_LIB ${AMS_APP_LIBRARIES})
+ list(APPEND NONSTATIC_AMS_APP_INTERFACE_LIBRARIES ${THIS_LIB})
+ if (TARGET ${THIS_LIB})
+ get_target_property(target_type ${THIS_LIB} TYPE)
+ if (NOT target_type STREQUAL STATIC_LIBRARY)
+ list(APPEND NONSTATIC_AMS_APP_PRIVATE_LIBRARIES ${THIS_LIB})
+ else()
+ add_library("${THIS_LIB}::nostatic" INTERFACE IMPORTED)
+ inherit_target_nostatic(TO "${THIS_LIB}::nostatic" FROM ${THIS_LIB})
+ list(APPEND NONSTATIC_AMS_APP_PRIVATE_LIBRARIES "${THIS_LIB}::nostatic")
+ endif()
+ else()
+ get_filename_component(THIS_EXT ${THIS_LIB} EXT)
+ if (NOT ".a" STREQUAL "${THIS_EXT}")
+ list(APPEND NONSTATIC_AMS_APP_PRIVATE_LIBRARIES ${THIS_LIB})
+ endif()
+ endif()
+ endforeach()
+endif()
+
add_subdirectory(src)
# ------------------------------------------------------------------------------
diff --git a/examples/main.cpp b/examples/main.cpp
index ed4f7a87..67ce65bb 100644
--- a/examples/main.cpp
+++ b/examples/main.cpp
@@ -179,13 +179,13 @@ int run(const char *device_name,
CALIPER(CALI_MARK_BEGIN("Setup");)
const bool use_device = std::strcmp(device_name, "cpu") != 0;
- AMSDBType dbType = AMSDBType::None;
+ AMSDBType dbType = AMSDBType::DDDBNone;
if (std::strcmp(db_type, "csv") == 0) {
- dbType = AMSDBType::CSV;
+ dbType = AMSDBType::DBCSV;
} else if (std::strcmp(db_type, "hdf5") == 0) {
- dbType = AMSDBType::HDF5;
+ dbType = AMSDBType::DBHDF5;
} else if (std::strcmp(db_type, "rmq") == 0) {
- dbType = AMSDBType::RMQ;
+ dbType = AMSDBType::DBRMQ;
}
AMSUQPolicy uq_policy;
diff --git a/src/AMSlib/CMakeLists.txt b/src/AMSlib/CMakeLists.txt
index 8341439d..e00c95e4 100644
--- a/src/AMSlib/CMakeLists.txt
+++ b/src/AMSlib/CMakeLists.txt
@@ -42,7 +42,13 @@ target_include_directories(AMS PUBLIC
$)
target_include_directories(AMS PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_link_directories(AMS PUBLIC ${AMS_APP_LIB_DIRS})
-target_link_libraries(AMS PUBLIC ${AMS_APP_LIBRARIES} stdc++fs)
+if (EXCLUDE_STATIC_LIBS)
+ target_link_libraries(AMS PRIVATE ${NONSTATIC_AMS_APP_PRIVATE_LIBRARIES} ${AMS_TORCH_LIBRARY} stdc++fs stdc++ /usr/tce/packages/intel/intel-2022.1.0/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libintlc.so m)
+ target_link_libraries(AMS INTERFACE ${NONSTATIC_AMS_APP_INTERFACE_LIBRARIES} stdc++fs stdc++ /usr/tce/packages/intel/intel-2022.1.0/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libintlc.so m)
+ target_link_options(AMS PRIVATE "-Wl,--unresolved-symbols=ignore-all")
+else()
+ target_link_libraries(AMS PUBLIC ${AMS_APP_LIBRARIES} stdc++fs stdc++ /usr/tce/packages/intel/intel-2022.1.0/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libintlc.so m)
+endif()
#-------------------------------------------------------------------------------
# create the configuration header file with the respective information
@@ -75,7 +81,7 @@ install(TARGETS AMS
DESTINATION lib)
install(EXPORT AMSTargets
- FILE AMS.cmake
+ FILE AMSConfig.cmake
DESTINATION lib/cmake/AMS)
install(FILES ${PROJECT_BINARY_DIR}/include/AMS.h DESTINATION include)
diff --git a/src/AMSlib/include/AMS.h b/src/AMSlib/include/AMS.h
index 64488948..51e6a8a8 100644
--- a/src/AMSlib/include/AMS.h
+++ b/src/AMSlib/include/AMS.h
@@ -55,7 +55,7 @@ typedef enum {
typedef enum { UBALANCED = 0, BALANCED } AMSExecPolicy;
-typedef enum { None = 0, CSV, REDIS, HDF5, RMQ } AMSDBType;
+typedef enum { DDDBNone = 0, DBCSV, DBREDIS, DBHDF5, DBRMQ } AMSDBType;
// TODO: create a cleaner interface that separates UQ type (FAISS, DeltaUQ) with policy (max, mean).
enum struct AMSUQPolicy {
diff --git a/src/AMSlib/ml/surrogate.hpp b/src/AMSlib/ml/surrogate.hpp
index 997ad17a..e9e7751f 100644
--- a/src/AMSlib/ml/surrogate.hpp
+++ b/src/AMSlib/ml/surrogate.hpp
@@ -410,6 +410,8 @@ class SurrogateModel
else
_load(new_path, "cuda");
}
+
+ AMSResourceType getModelResource() const { return model_resource; }
};
template
diff --git a/src/AMSlib/ml/uq.hpp b/src/AMSlib/ml/uq.hpp
index dabe140f..6b6227bb 100644
--- a/src/AMSlib/ml/uq.hpp
+++ b/src/AMSlib/ml/uq.hpp
@@ -63,6 +63,8 @@ class UQ
if (uqPolicy == AMSUQPolicy::RandomUQ)
randomUQ = std::make_unique(resourceLocation, threshold);
+
+ DBG(UQ, "UQ Model is of type %d", uqPolicy)
}
PERFFASPECT()
@@ -73,21 +75,40 @@ class UQ
{
if ((uqPolicy == AMSUQPolicy::DeltaUQ_Mean) ||
(uqPolicy == AMSUQPolicy::DeltaUQ_Max)) {
+
+ auto &rm = ams::ResourceManager::getInstance();
+
CALIPER(CALI_MARK_BEGIN("DELTAUQ");)
const size_t ndims = outputs.size();
std::vector outputs_stdev(ndims);
// TODO: Enable device-side allocation and predicate calculation.
- auto &rm = ams::ResourceManager::getInstance();
for (int dim = 0; dim < ndims; ++dim)
outputs_stdev[dim] =
rm.allocate(totalElements, AMSResourceType::HOST);
CALIPER(CALI_MARK_BEGIN("SURROGATE");)
- DBG(Workflow,
- "Model exists, I am calling DeltaUQ surrogate (for all data)");
+ DBG(UQ,
+ "Model exists, I am calling DeltaUQ surrogate [%ld %ld] -> (mu:[%ld "
+ "%ld], std:[%ld %ld])",
+ totalElements,
+ inputs.size(),
+ totalElements,
+ outputs.size(),
+ totalElements,
+ inputs.size());
surrogate->evaluate(totalElements, inputs, outputs, outputs_stdev);
CALIPER(CALI_MARK_END("SURROGATE");)
+ // FIXME: We do something sub-optimal. We copy all the data from the GPU
+ // to the CPU and then we compute the predicate. Then we copy back the computed
+ // predicate to the device. We should avoid this unecessary back and forth.
+ bool *predicate = p_ml_acceptable;
+ if (surrogate->getModelResource() == AMSResourceType::DEVICE) {
+ predicate = rm.allocate(totalElements, AMSResourceType::HOST);
+ rm.copy(p_ml_acceptable, predicate);
+ }
+
+
if (uqPolicy == AMSUQPolicy::DeltaUQ_Mean) {
for (size_t i = 0; i < totalElements; ++i) {
// Use double for increased precision, range in the calculation
@@ -95,7 +116,7 @@ class UQ
for (size_t dim = 0; dim < ndims; ++dim)
mean += outputs_stdev[dim][i];
mean /= ndims;
- p_ml_acceptable[i] = (mean < threshold);
+ predicate[i] = (mean < threshold);
}
} else if (uqPolicy == AMSUQPolicy::DeltaUQ_Max) {
for (size_t i = 0; i < totalElements; ++i) {
@@ -106,12 +127,17 @@ class UQ
break;
}
- p_ml_acceptable[i] = is_acceptable;
+ predicate[i] = is_acceptable;
}
} else {
THROW(std::runtime_error, "Invalid UQ policy");
}
+ if (surrogate->getModelResource() == AMSResourceType::DEVICE) {
+ rm.copy(predicate, p_ml_acceptable);
+ rm.deallocate(predicate, AMSResourceType::HOST);
+ }
+
for (int dim = 0; dim < ndims; ++dim)
rm.deallocate(outputs_stdev[dim], AMSResourceType::HOST);
CALIPER(CALI_MARK_END("DELTAUQ");)
diff --git a/src/AMSlib/wf/basedb.hpp b/src/AMSlib/wf/basedb.hpp
index 34e49d19..16ce5be2 100644
--- a/src/AMSlib/wf/basedb.hpp
+++ b/src/AMSlib/wf/basedb.hpp
@@ -226,7 +226,7 @@ class csvDB final : public FileDB
/**
* @brief Return the DB enumerationt type (File, Redis etc)
*/
- AMSDBType dbType() { return AMSDBType::CSV; };
+ AMSDBType dbType() { return AMSDBType::DBCSV; };
/**
* @brief Takes an input and an output vector each holding 1-D vectors data, and
@@ -310,7 +310,7 @@ class hdf5DB final : public FileDB
*/
hid_t getDataSet(hid_t group,
std::string dName,
- const size_t Chunk = 32L * 1024L * 1024L)
+ const size_t Chunk = 256L * 1024L)
{
// Our datasets a.t.m are 1-D vectors
const int nDims = 1;
@@ -488,7 +488,7 @@ class hdf5DB final : public FileDB
/**
* @brief Return the DB enumerationt type (File, Redis etc)
*/
- AMSDBType dbType() { return AMSDBType::HDF5; };
+ AMSDBType dbType() { return AMSDBType::DBHDF5; };
/**
@@ -586,7 +586,7 @@ class RedisDB : public BaseDB
/**
* @brief Return the DB enumerationt type (File, Redis etc)
*/
- AMSDBType dbType() { return AMSDBType::REDIS; };
+ AMSDBType dbType() { return AMSDBType::DBREDIS; };
inline std::string info() { return _redis->info(); }
@@ -2339,7 +2339,7 @@ class RabbitMQDB final : public BaseDB
/**
* @brief Return the DB enumerationt type (File, Redis etc)
*/
- AMSDBType dbType() { return AMSDBType::RMQ; };
+ AMSDBType dbType() { return AMSDBType::DBRMQ; };
void close()
{
@@ -2425,18 +2425,18 @@ class DBManager
}
switch (dbType) {
- case AMSDBType::CSV:
+ case AMSDBType::DBCSV:
return std::make_shared>(dbPath, rId);
#ifdef __ENABLE_REDIS__
- case AMSDBType::REDIS:
+ case AMSDBType::DBREDIS:
return std::make_shared>(dbPath, rId);
#endif
#ifdef __ENABLE_HDF5__
- case AMSDBType::HDF5:
+ case AMSDBType::DBHDF5:
return std::make_shared>(dbPath, rId);
#endif
#ifdef __ENABLE_RMQ__
- case AMSDBType::RMQ:
+ case AMSDBType::DBRMQ:
return std::make_shared>(dbPath, rId);
#endif
default:
diff --git a/src/AMSlib/wf/cuda/utilities.cuh b/src/AMSlib/wf/cuda/utilities.cuh
index 3969688a..90fe26d4 100644
--- a/src/AMSlib/wf/cuda/utilities.cuh
+++ b/src/AMSlib/wf/cuda/utilities.cuh
@@ -16,6 +16,7 @@
#include
#include "wf/resource_manager.hpp"
+#include "wf/debug.h"
//#include
//#include
@@ -35,23 +36,10 @@ __device__ __inline__ int pow2i(int e) { return 1 << e; }
inline void __cudaSafeCall(cudaError err, const char* file, const int line)
{
-#ifdef CUDA_ERROR_CHECK
- if (cudaSuccess != err) {
- fprintf(stderr,
- "cudaSafeCall() failed at %s:%i : %s\n",
+ CFATAL(CUDA, (cudaSuccess != err), "cudaSafeCall() failed at %s:%i : %s\n",
file,
line,
cudaGetErrorString(err));
-
- fprintf(stdout,
- "cudaSafeCall() failed at %s:%i : %s\n",
- file,
- line,
- cudaGetErrorString(err));
- exit(-1);
- }
-#endif
-
return;
}
diff --git a/src/AMSlib/wf/debug.cpp b/src/AMSlib/wf/debug.cpp
index d2379609..86e4a4ef 100644
--- a/src/AMSlib/wf/debug.cpp
+++ b/src/AMSlib/wf/debug.cpp
@@ -5,6 +5,11 @@
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/
+#include "wf/debug.h"
+
+#if defined(__ENABLE_CUDA__) && defined(__ENABLE_TORCH__)
+#include
+#endif
#include
#include
@@ -40,3 +45,62 @@ void memUsage(double& vm_usage, double& resident_set)
vm_usage = vsize / 1024.0;
resident_set = rss * page_size;
}
+
+void dumpTorchDeviceStats()
+{
+#if defined(__ENABLE_CUDA__) && defined(__ENABLE_TORCH__)
+ c10::cuda::CUDACachingAllocator::emptyCache();
+ int curr_device = c10::cuda::current_device();
+ c10::cuda::CUDACachingAllocator::DeviceStats stats =
+ c10::cuda::CUDACachingAllocator::getDeviceStats(curr_device);
+
+ DBG(TorchDeviceStats,
+ "Current device according to torch has id : %d",
+ curr_device);
+
+ for (auto S : stats.allocated_bytes) {
+ DBG(TorchDeviceStats,
+ "Allocated Current: %g (MBytes) Peak: %g (MBytes) Allocated: %G "
+ "(MBytes) Freed: "
+ "%g (MBytes)",
+ (double)(S.current) / (1024.0 * 1024.0),
+ (double)(S.peak) / (1024.0 * 1024.0),
+ (double)(S.allocated) / (1024.0 * 1024.0),
+ (double)(S.freed) / (1024.0 * 1024.0));
+ }
+
+ for (auto S : stats.reserved_bytes) {
+ DBG(TorchDeviceStats,
+ "Reserved Current: %g (MBytes) Peak: %g (MBytes) Allocated: %G "
+ "(MBytes) Freed: "
+ "%g (MBytes)",
+ (double)(S.current) / (1024.0 * 1024.0),
+ (double)(S.peak) / (1024.0 * 1024.0),
+ (double)(S.allocated) / (1024.0 * 1024.0),
+ (double)(S.freed) / (1024.0 * 1024.0));
+ }
+
+ for (auto S : stats.active_bytes) {
+ DBG(TorchDeviceStats,
+ "Active Current: %g (MBytes) Peak: %g (MBytes) Allocated: %G "
+ "(MBytes) Freed: "
+ "%g (MBytes)",
+ (double)(S.current) / (1024.0 * 1024.0),
+ (double)(S.peak) / (1024.0 * 1024.0),
+ (double)(S.allocated) / (1024.0 * 1024.0),
+ (double)(S.freed) / (1024.0 * 1024.0));
+ }
+
+ for (auto S : stats.inactive_split_bytes) {
+ DBG(TorchDeviceStats,
+ "Inactive Split Bytes Current: %g (MBytes) Peak: %g (MBytes) "
+ "Allocated: %G "
+ "(MBytes) Freed: "
+ "%g (MBytes)",
+ (double)(S.current) / (1024.0 * 1024.0),
+ (double)(S.peak) / (1024.0 * 1024.0),
+ (double)(S.allocated) / (1024.0 * 1024.0),
+ (double)(S.freed) / (1024.0 * 1024.0));
+ }
+#endif
+}
diff --git a/src/AMSlib/wf/debug.h b/src/AMSlib/wf/debug.h
index 4927c32e..1bcbe99b 100644
--- a/src/AMSlib/wf/debug.h
+++ b/src/AMSlib/wf/debug.h
@@ -31,6 +31,8 @@ enum AMSVerbosity {
void memUsage(double& vm_usage, double& resident_set);
+void dumpTorchDeviceStats();
+void deviceMemoryInfo(size_t *free, size_t *total);
inline std::atomic& getInfoLevelInternal()
{
@@ -80,7 +82,12 @@ inline uint32_t getVerbosityLevel()
#define FATAL(id, ...) CFATAL(id, true, __VA_ARGS__)
-#ifdef LIBAMS_VERBOSE
+#define THROW(exception, msg) \
+ throw exception(std::string(__FILE__) + ":" + std::to_string(__LINE__) + \
+ " " + msg)
+
+
+#ifdef AMS_DEBUG
#define CWARNING(id, condition, ...) \
AMSPRINT(id, condition, AMSVerbosity::AMSWARNING, YEL, __VA_ARGS__)
@@ -97,35 +104,54 @@ inline uint32_t getVerbosityLevel()
#define DBG(id, ...) CDEBUG(id, true, __VA_ARGS__)
-#define REPORT_MEM_USAGE(id, phase) \
- do { \
- double vm, rs; \
- size_t watermark, current_size, actual_size; \
- auto& rm = ams::ResourceManager::getInstance(); \
- memUsage(vm, rs); \
- DBG(id, "Memory usage at %s is VM:%g RS:%g", phase, vm, rs); \
- \
- for (int i = 0; i < AMSResourceType::RSEND; i++) { \
- if (rm.isActive((AMSResourceType)i)) { \
- rm.getAllocatorStats((AMSResourceType)i, \
- watermark, \
- current_size, \
- actual_size); \
- DBG(id, \
- "Allocator: %s HWM:%lu CS:%lu AS:%lu) ", \
- rm.getAllocatorName((AMSResourceType)i) \
- .c_str(), \
- watermark, \
- current_size, \
- actual_size); \
- } \
- } \
+#define REPORT_MEM_USAGE(id, phase) \
+ do { \
+ double vm, rs; \
+ size_t watermark, current_size, actual_size; \
+ auto& rm = ams::ResourceManager::getInstance(); \
+ memUsage(vm, rs); \
+ DBG(id, "Memory usage at %s is VM:%g RS:%g", phase, vm, rs); \
+ \
+ for (int i = 0; i < AMSResourceType::RSEND; i++) { \
+ if (rm.isActive((AMSResourceType)i)) { \
+ rm.getAllocatorStats((AMSResourceType)i, \
+ watermark, \
+ current_size, \
+ actual_size); \
+ DBG(id, \
+ "Allocator (in MBytes): %s HWM:%g CS:%g AS:%g) ", \
+ rm.getAllocatorName((AMSResourceType)i).c_str(), \
+ (double)(watermark) / (1024.0 * 1024.0), \
+ (double)(current_size) / (1024.0 * 1024.0), \
+ (double)(actual_size) / (1024.0 * 1024.0)); \
+ } \
+ } \
+ dumpTorchDeviceStats(); \
+ size_t free, total; \
+ deviceMemoryInfo(&free, &total); \
+ DBG(id, \
+ "Device Memory Usage (cuda-driver) Used: %g MBytes, Free: %g MBytes", \
+ ((double)(total - free)) / (1024.0 * 1024.0), \
+ (double)(free) / (1024.0 * 1024.0)); \
} while (0);
-#define THROW(exception, msg) \
- throw exception(std::string(__FILE__) + ":" + std::to_string(__LINE__) + \
- " " + msg)
-#else // LIBAMS_VERBOSE is disabled
+#ifdef __ENABLE_CUDA__
+// NOTE: Regardless of condition we synchronize. We only emit a message based on condition.
+#define _CAMSDebugDeviceSync(id, condition, fn, ln, ...) \
+ do{ \
+ AMSDeviceSync(fn, ln); \
+ CDEBUG(id, condition, __VA_ARGS__) \
+ }while(0);
+
+#define CAMSDebugDeviceSync(id, condition, ...) _CAMSDebugDeviceSync(id, condition, __FILE__, __LINE__, __VA_ARGS__)
+#define AMSDebugDeviceSync(id, ...) _CAMSDebugDeviceSync(id, true, __FILE__, __LINE__, __VA_ARGS__)
+#else
+#define CAMSDebugDeviceSync(id, condition, ...)
+#define AMSDebugDeviceSync(id, ...)
+#endif
+
+
+#else // LIBAMS_DEBUG is disabled
#define CWARNING(id, condition, ...)
#define WARNING(id, ...)
@@ -138,7 +164,11 @@ inline uint32_t getVerbosityLevel()
#define DBG(id, ...)
+#define REPORT_MEM_USAGE(id, phase) \
+
+#define CAMSDebugDeviceSync(id, condition, ...)
+#define AMSDebugDeviceSync(id, ...)
-#endif // LIBAMS_VERBOSE
+#endif // AMS_DEBUG
-#endif // _OMPTARGET_DEBUG_H
+#endif // __AMS_DEBUG__
diff --git a/src/AMSlib/wf/device.hpp b/src/AMSlib/wf/device.hpp
index 73b784c0..e1c6cc10 100644
--- a/src/AMSlib/wf/device.hpp
+++ b/src/AMSlib/wf/device.hpp
@@ -153,6 +153,13 @@ void deviceCheckErrors(const char *file, const int line)
}
+void deviceMemoryInfo(size_t *free, size_t *total)
+{
+#ifdef __ENABLE_CUDA__
+ cudaMemGetInfo(free, total);
+#endif
+}
+
#ifdef __ENABLE_CUDA__
#include
@@ -184,7 +191,6 @@ __global__ void random_uq_device(int seed,
uq_flags[id] = (x <= acceptable_error);
}
-
#include
PERFFASPECT()
inline void DtoDMemcpy(void *dest, void *src, size_t nBytes)
@@ -209,6 +215,11 @@ inline void DtoHMemcpy(void *dest, void *src, size_t nBytes)
{
cudaMemcpy(dest, src, nBytes, cudaMemcpyDeviceToHost);
}
+
+inline void AMSDeviceSync(const char* file, const int line){
+ __cudaSafeCall(cudaDeviceSynchronize(), file, line);
+}
+
#else
PERFFASPECT()
inline void DtoDMemcpy(void *dest, void *src, size_t nBytes)
@@ -236,6 +247,10 @@ inline void DtoHMemcpy(void *dest, void *src, size_t nBytes)
std::cerr << "DtoH Memcpy Not Enabled" << std::endl;
exit(-1);
}
+
+inline void AMSDeviceSync(const char* file, const int line){
+ std::cerr << "GPU Not enabled" << std::endl;
+}
#endif
#endif
diff --git a/src/AMSlib/wf/redist_load.hpp b/src/AMSlib/wf/redist_load.hpp
index bbedf3b5..3f3d05f8 100644
--- a/src/AMSlib/wf/redist_load.hpp
+++ b/src/AMSlib/wf/redist_load.hpp
@@ -103,77 +103,10 @@ class AMSLoadBalancer
/** @brief The memory location of the data (GPU (DEVICE), CPU (HOST) ) */
AMSResourceType resource;
-private:
- /** @brief Computes the number of balanced elements each process will gather and initializes
- * memory structures.
- * @param[in] numIn The number of input dimensions
- * @param[in] numOut The number of input dimensions
- * @param[in] resource The resource type to allocate data in.
-
- * @details The function computes the total number of elements every rank will need to balance.
- * It initializes the 'dataElements', 'displs' on the root node and the localLoad, balancedLoad
- * across all ranks.
- */
- void init(int numIn, int numOut, AMSResourceType resource)
- {
- auto& rm = ams::ResourceManager::getInstance();
- // We need to store information
- if (rId == root) {
- dataElements =
- rm.allocate(worldSize, AMSResourceType::HOST);
- displs = rm.allocate(worldSize + 1,
- AMSResourceType::HOST);
- }
-
- // Gather the the number of items from each rank
- int rc = MPI_Gather(reinterpret_cast(&localLoad),
- 1,
- MPI_INT,
- reinterpret_cast(dataElements),
- 1,
- MPI_INT,
- root,
- Comm);
- CFATAL(LoadBalance, rc != MPI_SUCCESS, "Cannot gather per rank sizes")
-
- // Populate displacement array
- if (rId == 0) {
- globalLoad = 0;
- displs[0] = static_cast(0);
- for (size_t i = 0ul; i < worldSize; ++i) {
- displs[i + 1] = dataElements[i] + displs[i];
- }
- globalLoad = displs[worldSize];
- }
-
- balancedLoad = computeBalanceLoad();
-
- if (rId == root) {
- balancedElements =
- rm.ResourceManager::allocate(worldSize, AMSResourceType::HOST);
- balancedDispls =
- rm.ResourceManager::allocate(worldSize, AMSResourceType::HOST);
- for (int i = 0; i < worldSize; i++) {
- balancedElements[i] = (globalLoad / worldSize) +
- static_cast(i < (globalLoad % worldSize));
- if (i != 0)
- balancedDispls[i] = balancedElements[i] + balancedDispls[i - 1];
- else
- balancedDispls[i] = 0;
- }
- }
-
- for (int i = 0; i < numIn; i++) {
- distInputs.push_back(
- rm.allocate(balancedLoad, resource));
- }
-
- for (int i = 0; i < numOut; i++) {
- distOutputs.push_back(
- rm.allocate(balancedLoad, resource));
- }
- }
+ /** @brief Check if load balancer is initialized **/
+ bool _initialized;
+private:
/** @brief Computes the number of elements every rank will receive after balancing.
* @returns the number of elements computed by this rank.
**/
@@ -266,10 +199,11 @@ class AMSLoadBalancer
AMSResourceType resource)
{
FPTypeValue *temp_data;
- auto& rm = ams::ResourceManager::getInstance();
+ auto &rm = ams::ResourceManager::getInstance();
if (rId == root) {
- temp_data = rm.ResourceManager::allocate(globalLoad, resource);
+ temp_data =
+ rm.ResourceManager::allocate(globalLoad, resource);
}
for (int i = 0; i < src.size(); i++) {
@@ -304,17 +238,8 @@ class AMSLoadBalancer
* @param[in] worldSize The total number of ranks in respect to the Comm communicator.
* @param[in] localLoad The number of elements this rank has to compute originally (before load balance).
* @param[in] Comm The MPI communicator.
- * @param[in] numIn The number of input vectors to be balanced.
- * @param[in] numOut The number of output vectors to be balanced.
- * @param[in] resource The location of data allocations (CPU|GPU).
- */
- AMSLoadBalancer(int rId,
- int worldSize,
- int localLoad,
- MPI_Comm comm,
- int numIn,
- int numOut,
- AMSResourceType resource)
+ */
+ AMSLoadBalancer(int rId, int worldSize, int localLoad, MPI_Comm comm)
: rId(rId),
worldSize(worldSize),
localLoad(localLoad),
@@ -324,23 +249,27 @@ class AMSLoadBalancer
dataElements(nullptr),
balancedElements(nullptr),
balancedDispls(nullptr),
- resource(resource)
+ resource(resource),
+ _initialized(false)
{
- init(numIn, numOut, resource);
}
/** @brief deallocates all objects of this load balancing transcation */
~AMSLoadBalancer()
{
- auto& rm = ams::ResourceManager::getInstance();
- CINFO(LoadBalance, root==rId, "Total data %d Data per rank %d", globalLoad, balancedLoad);
+ if (!_initialized) return;
+
+ auto &rm = ams::ResourceManager::getInstance();
+ CINFO(LoadBalance,
+ root == rId,
+ "Total data %d Data per rank %d",
+ globalLoad,
+ balancedLoad);
if (displs) rm.deallocate(displs, AMSResourceType::HOST);
- if (dataElements)
- rm.deallocate(dataElements, AMSResourceType::HOST);
+ if (dataElements) rm.deallocate(dataElements, AMSResourceType::HOST);
if (balancedElements)
rm.deallocate(balancedElements, AMSResourceType::HOST);
- if (balancedDispls)
- rm.deallocate(balancedDispls, AMSResourceType::HOST);
+ if (balancedDispls) rm.deallocate(balancedDispls, AMSResourceType::HOST);
for (int i = 0; i < distOutputs.size(); i++)
rm.deallocate(distOutputs[i], resource);
@@ -350,6 +279,73 @@ class AMSLoadBalancer
}
};
+ /** @brief Computes the number of balanced elements each process will gather and initializes
+ * memory structures.
+ * @param[in] numIn The number of input dimensions
+ * @param[in] numOut The number of input dimensions
+ * @param[in] resource The resource type to allocate data in.
+
+ * @details The function computes the total number of elements every rank will need to balance.
+ * It initializes the 'dataElements', 'displs' on the root node and the localLoad, balancedLoad
+ * across all ranks.
+ */
+ void init(int numIn, int numOut, AMSResourceType resource)
+ {
+ auto &rm = ams::ResourceManager::getInstance();
+ // We need to store information
+ if (rId == root) {
+ dataElements = rm.allocate(worldSize, AMSResourceType::HOST);
+ displs = rm.allocate(worldSize + 1, AMSResourceType::HOST);
+ }
+
+ // Gather the the number of items from each rank
+ int rc = MPI_Gather(reinterpret_cast(&localLoad),
+ 1,
+ MPI_INT,
+ reinterpret_cast(dataElements),
+ 1,
+ MPI_INT,
+ root,
+ Comm);
+ CFATAL(LoadBalance, rc != MPI_SUCCESS, "Cannot gather per rank sizes")
+
+ // Populate displacement array
+ if (rId == 0) {
+ globalLoad = 0;
+ displs[0] = static_cast(0);
+ for (size_t i = 0ul; i < worldSize; ++i) {
+ displs[i + 1] = dataElements[i] + displs[i];
+ }
+ globalLoad = displs[worldSize];
+ }
+
+ balancedLoad = computeBalanceLoad();
+
+ if (rId == root) {
+ balancedElements =
+ rm.ResourceManager::allocate(worldSize, AMSResourceType::HOST);
+ balancedDispls =
+ rm.ResourceManager::allocate(worldSize, AMSResourceType::HOST);
+ for (int i = 0; i < worldSize; i++) {
+ balancedElements[i] = (globalLoad / worldSize) +
+ static_cast(i < (globalLoad % worldSize));
+ if (i != 0)
+ balancedDispls[i] = balancedElements[i] + balancedDispls[i - 1];
+ else
+ balancedDispls[i] = 0;
+ }
+ }
+
+ for (int i = 0; i < numIn; i++) {
+ distInputs.push_back(rm.allocate(balancedLoad, resource));
+ }
+
+ for (int i = 0; i < numOut; i++) {
+ distOutputs.push_back(rm.allocate(balancedLoad, resource));
+ }
+ _initialized = true;
+ }
+
/**
* @brief Reverse load balance in respect to the output vectors.
* @param[out] outputs The vector to store all the output values gathered from their compute (remote) ranks.
diff --git a/src/AMSlib/wf/workflow.hpp b/src/AMSlib/wf/workflow.hpp
index 2f1c8656..6003f53d 100644
--- a/src/AMSlib/wf/workflow.hpp
+++ b/src/AMSlib/wf/workflow.hpp
@@ -63,7 +63,7 @@ class AMSWorkflow
std::shared_ptr> DB;
/** @brief The type of the database we will use (HDF5, CSV, etc) */
- AMSDBType dbType = AMSDBType::None;
+ AMSDBType dbType = AMSDBType::DDDBNone;
/** @brief The process id. For MPI runs this is the rank */
const int rId;
@@ -141,7 +141,6 @@ class AMSWorkflow
DB->store(actualElems, hInputs, hOutputs);
}
rm.deallocate(pPtr, AMSResourceType::PINNED);
-
return;
}
@@ -149,7 +148,7 @@ class AMSWorkflow
AMSWorkflow()
: AppCall(nullptr),
DB(nullptr),
- dbType(AMSDBType::None),
+ dbType(AMSDBType::DDDBNone),
appDataLoc(AMSResourceType::HOST),
ePolicy(AMSExecPolicy::UBALANCED)
{
@@ -251,6 +250,7 @@ class AMSWorkflow
{
CALIPER(CALI_MARK_BEGIN("AMSEvaluate");)
+ CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-External Errors")
CDEBUG(Workflow,
rId == 0,
"Entering Evaluate "
@@ -276,13 +276,17 @@ class AMSWorkflow
totalElements,
reinterpret_cast(origInputs.data()),
reinterpret_cast(origOutputs.data()));
+
+ CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-AppCall Errors")
CALIPER(CALI_MARK_END("PHYSICS MODULE");)
if (DB) {
CALIPER(CALI_MARK_BEGIN("DBSTORE");)
Store(totalElements, tmpIn, origOutputs);
+ CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-Store Errors")
CALIPER(CALI_MARK_END("DBSTORE");)
}
CALIPER(CALI_MARK_END("AMSEvaluate");)
+ CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-Evaluate Errors")
return;
}
@@ -300,6 +304,7 @@ class AMSWorkflow
CALIPER(CALI_MARK_BEGIN("UQ_MODULE");)
UQModel->evaluate(totalElements, origInputs, origOutputs, p_ml_acceptable);
CALIPER(CALI_MARK_END("UQ_MODULE");)
+ CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-UQ/Surrogate Errors")
DBG(Workflow, "Computed Predicates")
@@ -324,6 +329,7 @@ class AMSWorkflow
const long packedElements = data_handler::pack(
appDataLoc, predicate, totalElements, origInputs, packedInputs);
CALIPER(CALI_MARK_END("PACK");)
+ CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-Pack Errors")
// Pointer values which store output data values
// to be computed using the eos function.
@@ -341,8 +347,9 @@ class AMSWorkflow
#ifdef __ENABLE_MPI__
CALIPER(CALI_MARK_BEGIN("LOAD BALANCE MODULE");)
AMSLoadBalancer lBalancer(
- rId, wSize, packedElements, Comm, inputDim, outputDim, appDataLoc);
+ rId, wSize, packedElements, Comm);
if (ePolicy == AMSExecPolicy::BALANCED && Comm) {
+ lBalancer.init(inputDim, outputDim, appDataLoc);
lBalancer.scatterInputs(packedInputs, appDataLoc);
iPtr = reinterpret_cast(lBalancer.inputs());
oPtr = reinterpret_cast(lBalancer.outputs());
@@ -351,12 +358,14 @@ class AMSWorkflow
CALIPER(CALI_MARK_END("LOAD BALANCE MODULE");)
#endif
+
// ---- 3b: call the physics module and store in the data base
if (packedElements > 0) {
CALIPER(CALI_MARK_BEGIN("PHYSICS MODULE");)
AppCall(probDescr, lbElements, iPtr, oPtr);
CALIPER(CALI_MARK_END("PHYSICS MODULE");)
}
+ CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-AppCall Errors")
#ifdef __ENABLE_MPI__
CALIPER(CALI_MARK_BEGIN("LOAD BALANCE MODULE");)
@@ -372,6 +381,7 @@ class AMSWorkflow
data_handler::unpack(
appDataLoc, predicate, totalElements, packedOutputs, origOutputs);
CALIPER(CALI_MARK_END("UNPACK");)
+ CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-UnPack Errors")
DBG(Workflow, "Finished physics evaluation")
@@ -383,6 +393,7 @@ class AMSWorkflow
Store(packedElements, packedInputs, packedOutputs);
CALIPER(CALI_MARK_END("DBSTORE");)
}
+ CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-Store Errors")
// -----------------------------------------------------------------
// Deallocate temporal data
@@ -405,6 +416,7 @@ class AMSWorkflow
REPORT_MEM_USAGE(Workflow, "End")
CALIPER(CALI_MARK_END("AMSEvaluate");)
+ CAMSDebugDeviceSync(Workflow, rId == 0, "DeviceSyncrhonize, No-Evaluate Errors")
}
};
diff --git a/tests/AMSlib/CMakeLists.txt b/tests/AMSlib/CMakeLists.txt
index 5385e079..e9290f3a 100644
--- a/tests/AMSlib/CMakeLists.txt
+++ b/tests/AMSlib/CMakeLists.txt
@@ -5,9 +5,9 @@
function (BUILD_TEST exe source)
add_executable(${exe} ${source})
- target_include_directories(${exe} PRIVATE "${PROJECT_SOURCE_DIR}/src/AMSlib/" umpire ${caliper_INCLUDE_DIR} ${MPI_INCLUDE_PATH})
+ target_include_directories(${exe} PRIVATE "${PROJECT_SOURCE_DIR}/src/AMSlib/" umpire ${AMS_APP_INCLUDES} ${caliper_INCLUDE_DIR} ${MPI_INCLUDE_PATH})
target_link_directories(${exe} PRIVATE ${AMS_APP_LIB_DIRS})
- target_link_libraries(${exe} PRIVATE AMS ${AMS_APP_LIBRARIES})
+ target_link_libraries(${exe} PRIVATE AMS ${AMS_APP_LIBRARIES} ${AMS_TORCH_LIBRARY})
target_compile_definitions(${exe} PRIVATE ${AMS_APP_DEFINES})
if (WITH_CUDA)
@@ -39,6 +39,19 @@ if (WITH_TORCH)
add_test(NAME AMSExampleSingleRandomUQ::HOST COMMAND ams_example --precision single --uqtype random -S ${CMAKE_CURRENT_SOURCE_DIR}/debug_model.pt -e 100)
add_test(NAME AMSExampleDoubleRandomUQ::HOST COMMAND ams_example --precision double --uqtype random -S ${CMAKE_CURRENT_SOURCE_DIR}/debug_model.pt -e 100)
+ # UQ Tests
+ BUILD_TEST(ams_delta_uq_test ams_uq_model.cpp)
+ if (WITH_TORCH)
+ add_test(NAME AMSDeltaUQDoubleMean::HOST COMMAND ams_delta_uq_test 0 ${CMAKE_CURRENT_SOURCE_DIR}/torch.duq.cuda "double" 8 9 3 0.0)
+ add_test(NAME AMSDeltaUQDoubleMax::HOST COMMAND ams_delta_uq_test 0 ${CMAKE_CURRENT_SOURCE_DIR}/torch.duq.cuda "double" 8 9 4 0.0)
+
+ if (WITH_CUDA)
+ add_test(NAME AMSDeltaUQDoubleMean::DEVICE COMMAND ams_delta_uq_test 1 ${CMAKE_CURRENT_SOURCE_DIR}/torch.duq.cuda "double" 8 9 3 0.0)
+ add_test(NAME AMSDeltaUQDoubleMax::DEVICE COMMAND ams_delta_uq_test 1 ${CMAKE_CURRENT_SOURCE_DIR}/torch.duq.cuda "double" 8 9 4 0.0)
+ endif()
+ endif()
+ #TODO Add tests with cpu model
+
BUILD_TEST(ams_update_model ams_update_model.cpp)
ADDTEST(ams_update_model AMSUpdateModelDouble "double" ${CMAKE_CURRENT_SOURCE_DIR}/ConstantZeroModel_cpu.pt ${CMAKE_CURRENT_SOURCE_DIR}/ConstantOneModel_cpu.pt)
endif()
diff --git a/tests/AMSlib/ams_uq_model.cpp b/tests/AMSlib/ams_uq_model.cpp
new file mode 100644
index 00000000..4a7262c8
--- /dev/null
+++ b/tests/AMSlib/ams_uq_model.cpp
@@ -0,0 +1,77 @@
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#define SIZE (32L * 1024L + 3L)
+
+template
+void model(UQ &model,
+ AMSResourceType resource,
+ int num_inputs,
+ int num_outputs)
+{
+ std::vector inputs;
+ std::vector outputs;
+ auto &ams_rm = ams::ResourceManager::getInstance();
+
+ for (int i = 0; i < num_inputs; i++)
+ inputs.push_back(ams_rm.allocate(SIZE, resource));
+
+ for (int i = 0; i < num_outputs; i++)
+ outputs.push_back(ams_rm.allocate(SIZE, resource));
+
+ bool *predicates = ams_rm.allocate(SIZE, resource);
+
+ std::cout << "We are calling evaluate\n";
+ model.evaluate(SIZE, inputs, outputs, predicates);
+
+
+ for (int i = 0; i < num_inputs; i++)
+ ams_rm.deallocate(const_cast(inputs[i]), resource);
+
+ for (int i = 0; i < num_outputs; i++)
+ ams_rm.deallocate(outputs[i], resource);
+
+ ams_rm.deallocate(predicates, resource);
+}
+
+
+int main(int argc, char *argv[])
+{
+ using namespace ams;
+ auto &ams_rm = ResourceManager::getInstance();
+ int use_device = std::atoi(argv[1]);
+ char *model_path = argv[2];
+ char *data_type = argv[3];
+ int num_inputs = std::atoi(argv[4]);
+ int num_outputs = std::atoi(argv[5]);
+ const AMSUQPolicy uq_policy = static_cast(std::atoi(argv[6]));
+ float threshold = std::atof(argv[7]);
+
+ std::cout << "Executing on device " << use_device << "\n";
+
+ AMSResourceType resource = AMSResourceType::HOST;
+ if (use_device == 1) {
+ resource = AMSResourceType::DEVICE;
+ }
+
+ ams_rm.init();
+
+
+ if (std::strcmp("double", data_type) == 0) {
+ UQ UQModel(resource, uq_policy, nullptr, -1, model_path, threshold);
+ model(UQModel, resource, num_inputs, num_outputs);
+ } else if (std::strcmp("single", data_type) == 0) {
+ UQ UQModel(resource, uq_policy, nullptr, -1, model_path, threshold);
+ model(UQModel, resource, num_inputs, num_outputs);
+ }
+
+ return 0;
+}
diff --git a/tests/AMSlib/torch.duq b/tests/AMSlib/torch.duq
new file mode 100644
index 00000000..58ee24fc
Binary files /dev/null and b/tests/AMSlib/torch.duq differ
diff --git a/tests/AMSlib/torch.duq.cuda b/tests/AMSlib/torch.duq.cuda
new file mode 100644
index 00000000..1a80a5a1
Binary files /dev/null and b/tests/AMSlib/torch.duq.cuda differ